未验证 提交 74266762 编写于 作者: Z Zhang Zheng 提交者: GitHub

Cinn error refactor (#55544)

* Refactor the error message system

* fix header

* fix compile
上级 12fb18dd
......@@ -177,7 +177,7 @@ void TestSplitThrow() {
std::vector<Expr> vec_ast{ast_expr};
ir::ModuleExpr mod_expr(vec_ast);
ir::IRSchedule ir_sch(
mod_expr, -1, false, ir::ScheduleErrorMessageLevel::kGeneral);
mod_expr, -1, false, utils::ErrorMessageLevel::kGeneral);
auto fused = ir_sch.Fuse("B", {0, 1});
// statement that cause the exception
auto splited = ir_sch.Split(fused, {-1, -1});
......@@ -196,7 +196,7 @@ void TestSplitThrow() {
auto source_code = codegen.Compile(module, CodeGenC::OutputKind::CImpl);
}
TEST(IrSchedule, split_throw) {
ASSERT_THROW(TestSplitThrow(), ir::enforce::EnforceNotMet);
ASSERT_THROW(TestSplitThrow(), utils::enforce::EnforceNotMet);
}
TEST(IrSchedule, reorder1) {
......
......@@ -41,7 +41,7 @@
#include "paddle/cinn/optim/replace_var_with_expr.h"
#include "paddle/cinn/utils/string.h"
DECLARE_int32(cinn_schedule_error_message_level);
DECLARE_int32(cinn_error_message_level);
namespace cinn {
namespace ir {
......@@ -54,12 +54,11 @@ class ScheduleImpl {
ScheduleImpl() = default;
explicit ScheduleImpl(const ModuleExpr& module_expr,
bool debug_flag = false,
ScheduleErrorMessageLevel err_msg_level =
ScheduleErrorMessageLevel::kGeneral)
utils::ErrorMessageLevel err_msg_level =
utils::ErrorMessageLevel::kGeneral)
: module_expr_(module_expr), debug_flag_(debug_flag) {
err_msg_level_ = static_cast<ScheduleErrorMessageLevel>(
FLAGS_cinn_schedule_error_message_level ||
static_cast<int>(err_msg_level));
err_msg_level_ = static_cast<utils::ErrorMessageLevel>(
FLAGS_cinn_error_message_level || static_cast<int>(err_msg_level));
}
explicit ScheduleImpl(ModuleExpr&& module_expr)
: module_expr_(std::move(module_expr)) {}
......@@ -138,8 +137,7 @@ class ScheduleImpl {
ModuleExpr module_expr_;
bool debug_flag_{false};
ScheduleErrorMessageLevel err_msg_level_ =
ScheduleErrorMessageLevel::kGeneral;
utils::ErrorMessageLevel err_msg_level_ = utils::ErrorMessageLevel::kGeneral;
};
/** \brief A macro that guards the beginning of each implementation of schedule
......@@ -152,10 +150,10 @@ class ScheduleImpl {
* @param err_msg_level A ScheduleErrorMessageLevel enum, level of error message
* printing
*/
#define CINN_IR_SCHEDULE_END(primitive, err_msg_level) \
} \
catch (const IRScheduleErrorHandler& err_hanlder) { \
CINN_THROW(err_hanlder.FormatErrorMessage(primitive, err_msg_level)); \
#define CINN_IR_SCHEDULE_END(err_msg_level) \
} \
catch (const utils::ErrorHandler& err_hanlder) { \
CINN_THROW(err_hanlder.FormatErrorMessage(err_msg_level)); \
}
std::vector<Expr> ScheduleImpl::Split(const Expr& loop,
......@@ -177,7 +175,7 @@ std::vector<Expr> ScheduleImpl::Split(const Expr& loop,
std::vector<int> processed_factors;
CINN_IR_SCHEDULE_BEGIN();
processed_factors = ValidateFactors(factors, tot_extent, this->module_expr_);
CINN_IR_SCHEDULE_END("split", this->err_msg_level_);
CINN_IR_SCHEDULE_END(this->err_msg_level_);
int prod_size = std::accumulate(processed_factors.begin(),
processed_factors.end(),
1,
......@@ -2316,7 +2314,7 @@ IRSchedule::IRSchedule() {}
IRSchedule::IRSchedule(const ModuleExpr& module_expr,
utils::LinearRandomEngine::StateType rand_seed,
bool debug_flag,
ScheduleErrorMessageLevel err_msg_level) {
utils::ErrorMessageLevel err_msg_level) {
impl_ =
std::make_unique<ScheduleImpl>(module_expr, debug_flag, err_msg_level);
this->InitSeed(rand_seed);
......
......@@ -24,25 +24,12 @@
#include "paddle/cinn/ir/schedule/schedule_desc.h"
#include "paddle/cinn/ir/tensor.h"
#include "paddle/cinn/ir/utils/ir_mutator.h"
#include "paddle/cinn/utils/error.h"
#include "paddle/cinn/utils/random_engine.h"
namespace cinn {
namespace ir {
/**
* \brief Indicates the level of printing error message in the current Schedule
*/
enum class ScheduleErrorMessageLevel : int32_t {
/** \brief Print an error message in short mode.
* Short mode shows which and where the schedule error happens*/
kGeneral = 0,
/** \brief Print an error message in detailed mode.
* Detailed mode shows which and where the schedule error happens, and the
* schedule input parameters.
*/
kDetailed = 1,
};
/**
* A struct representing a module that contains Expr. This struct is only used
* in Schedule process.
......@@ -85,8 +72,8 @@ class IRSchedule {
explicit IRSchedule(const ModuleExpr& modexpr,
utils::LinearRandomEngine::StateType rand_seed = -1,
bool debug_flag = false,
ScheduleErrorMessageLevel err_msg_level =
ScheduleErrorMessageLevel::kGeneral);
utils::ErrorMessageLevel err_msg_level =
utils::ErrorMessageLevel::kGeneral);
IRSchedule(ir::ModuleExpr&& mod_expr,
ScheduleDesc&& trace,
utils::LinearRandomEngine::StateType rand_seed = -1);
......
......@@ -20,7 +20,11 @@ namespace cinn {
namespace ir {
std::string IRScheduleErrorHandler::GeneralErrorMessage() const {
return this->err_msg_;
std::ostringstream os;
os << "[IRScheduleError] An error occurred in the scheduel primitive < "
<< this->primitive_ << " >. " << std::endl;
os << this->err_msg_;
return os.str();
}
std::string IRScheduleErrorHandler::DetailedErrorMessage() const {
......@@ -31,44 +35,5 @@ std::string IRScheduleErrorHandler::DetailedErrorMessage() const {
return os.str();
}
std::string IRScheduleErrorHandler::FormatErrorMessage(
const std::string& primitive,
const ScheduleErrorMessageLevel& err_msg_level) const {
std::ostringstream os;
std::string err_msg = err_msg_level == ScheduleErrorMessageLevel::kDetailed
? DetailedErrorMessage()
: GeneralErrorMessage();
os << "[IRScheduleError] An error occurred in the scheduel primitive <"
<< primitive << ">. " << std::endl;
os << "[Error info] " << err_msg;
return os.str();
}
std::string NegativeFactorErrorMessage(const int64_t& factor,
const size_t& idx) {
std::ostringstream os;
os << "The params in factors of Split should be positive. However, the "
"factor at position "
<< idx << " is " << factor << std::endl;
return os.str();
}
std::string InferFactorErrorMessage() {
std::ostringstream os;
os << "The params in factors of Split should not be less than -1 or have "
"more than one -1!"
<< std::endl;
return os.str();
}
std::string FactorProductErrorMessage() {
std::ostringstream os;
os << "In Split, the factors' product should be not larger than or equal "
"to original loop's extent!"
<< std::endl;
return os.str();
}
} // namespace ir
} // namespace cinn
......@@ -14,130 +14,25 @@
#pragma once
#ifdef __GNUC__
#include <cxxabi.h> // for __cxa_demangle
#endif // __GNUC__
#if !defined(_WIN32)
#include <dlfcn.h> // dladdr
#include <unistd.h> // sleep, usleep
#else // _WIN32
#ifndef NOMINMAX
#define NOMINMAX // msvc max/min macro conflict with std::min/max
#endif
#include <windows.h> // GetModuleFileName, Sleep
#endif
#if !defined(_WIN32) && !defined(PADDLE_WITH_MUSL)
#include <execinfo.h>
#endif
#include <iostream>
#include <memory>
#include <sstream>
#include <stdexcept>
#include <string>
#include <vector>
#include "paddle/cinn/ir/schedule/ir_schedule.h"
namespace cinn {
namespace ir {
namespace enforce {
#ifdef __GNUC__
inline std::string demangle(std::string name) {
int status = -4; // some arbitrary value to eliminate the compiler warning
std::unique_ptr<char, void (*)(void*)> res{
abi::__cxa_demangle(name.c_str(), NULL, NULL, &status), std::free};
return (status == 0) ? res.get() : name;
}
#else
inline std::string demangle(std::string name) { return name; }
#endif
static std::string GetErrorSumaryString(const std::string& what,
const char* file,
int line) {
std::ostringstream sout;
sout << "\n----------------------\nError Message "
"Summary:\n----------------------\n";
sout << what << "(at " << file << " : " << line << ")" << std::endl;
return sout.str();
}
static std::string GetCurrentTraceBackString() {
std::ostringstream sout;
sout << "\n\n--------------------------------------\n";
sout << "C++ Traceback (most recent call last):";
sout << "\n--------------------------------------\n";
#if !defined(_WIN32) && !defined(PADDLE_WITH_MUSL)
static constexpr int TRACE_STACK_LIMIT = 100;
void* call_stack[TRACE_STACK_LIMIT];
auto size = backtrace(call_stack, TRACE_STACK_LIMIT);
auto symbols = backtrace_symbols(call_stack, size);
Dl_info info;
int idx = 0;
int end_idx = 0;
for (int i = size - 1; i >= end_idx; --i) {
if (dladdr(call_stack[i], &info) && info.dli_sname) {
auto demangled = demangle(info.dli_sname);
std::string path(info.dli_fname);
// C++ traceback info are from core.so
if (path.substr(path.length() - 3).compare(".so") == 0) {
sout << idx++ << " " << demangled << "\n";
}
}
}
free(symbols);
#else
sout << "Not support stack backtrace yet.\n";
#endif
return sout.str();
}
static std::string GetTraceBackString(const std::string& what,
const char* file,
int line) {
return GetCurrentTraceBackString() + GetErrorSumaryString(what, file, line);
}
struct EnforceNotMet : public std::exception {
public:
EnforceNotMet(const std::string& str, const char* file, int line)
: err_str_(GetTraceBackString(str, file, line)) {}
const char* what() const noexcept override { return err_str_.c_str(); }
private:
std::string err_str_;
};
#define CINN_THROW(...) \
do { \
try { \
throw enforce::EnforceNotMet(__VA_ARGS__, __FILE__, __LINE__); \
} catch (const std::exception& e) { \
std::cout << e.what() << std::endl; \
throw; \
} \
} while (0)
} // namespace enforce
/**
* This handler is dealing with the errors happen in in the current
* Scheduling.
*/
class IRScheduleErrorHandler {
class IRScheduleErrorHandler : public utils::ErrorHandler {
public:
/**
* \brief constructor
* \param err_msg the error message
*/
explicit IRScheduleErrorHandler(const std::string& err_msg,
explicit IRScheduleErrorHandler(const std::string& primitive,
const std::string& err_msg,
const ModuleExpr& module_expr)
: err_msg_(err_msg), module_expr_(module_expr) {}
: primitive_(primitive), err_msg_(err_msg), module_expr_(module_expr) {}
/**
* \brief Returns a short error message corresponding to the kGeneral error
......@@ -151,25 +46,11 @@ class IRScheduleErrorHandler {
*/
std::string DetailedErrorMessage() const;
/**
* \brief Returns a detailed error message corresponding to the kDetailed
* error level.
*/
std::string FormatErrorMessage(
const std::string& primitive,
const ScheduleErrorMessageLevel& err_msg_level) const;
private:
ModuleExpr module_expr_;
std::string primitive_;
std::string err_msg_;
ModuleExpr module_expr_;
};
std::string NegativeFactorErrorMessage(const int64_t& factor,
const size_t& idx);
std::string InferFactorErrorMessage();
std::string FactorProductErrorMessage();
} // namespace ir
} // namespace cinn
......@@ -222,6 +222,7 @@ void ReplaceExpr(Expr* source,
std::vector<int> ValidateFactors(const std::vector<int>& factors,
int total_extent,
const ModuleExpr& module_expr) {
const std::string primitive = "split";
CHECK(!factors.empty())
<< "The factors param of Split should not be empty! Please check.";
bool has_minus_one = false;
......@@ -230,11 +231,19 @@ std::vector<int> ValidateFactors(const std::vector<int>& factors,
for (auto& i : factors) {
idx++;
if (i == 0 || i < -1) {
throw IRScheduleErrorHandler(NegativeFactorErrorMessage(i, idx),
module_expr);
std::ostringstream os;
os << "The params in factors of Split should be positive. However, the "
"factor at position "
<< idx << " is " << i << std::endl;
throw IRScheduleErrorHandler(primitive, os.str(), module_expr);
} else if (i == -1) {
if (has_minus_one) {
throw IRScheduleErrorHandler(InferFactorErrorMessage(), module_expr);
std::ostringstream os;
os << "The params in factors of Split should not be less than -1 or "
"have "
"more than one -1!"
<< std::endl;
throw IRScheduleErrorHandler(primitive, os.str(), module_expr);
}
has_minus_one = true;
} else {
......@@ -244,12 +253,20 @@ std::vector<int> ValidateFactors(const std::vector<int>& factors,
std::vector<int> validated_factors = factors;
if (!has_minus_one) {
if (product < total_extent) {
throw IRScheduleErrorHandler(FactorProductErrorMessage(), module_expr);
std::ostringstream os;
os << "In Split, the factors' product should be not larger than or equal "
"to original loop's extent!"
<< std::endl;
throw IRScheduleErrorHandler(primitive, os.str(), module_expr);
}
return validated_factors;
} else {
if (product > total_extent) {
throw IRScheduleErrorHandler(FactorProductErrorMessage(), module_expr);
std::ostringstream os;
os << "In Split, the factors' product should be not larger than or equal "
"to original loop's extent!"
<< std::endl;
throw IRScheduleErrorHandler(primitive, os.str(), module_expr);
}
int minus_one_candidate = static_cast<int>(
ceil(static_cast<double>(total_extent) / static_cast<double>(product)));
......
......@@ -164,8 +164,8 @@ DEFINE_int32(cinn_profiler_state,
"Specify the ProfilerState by Int in CINN, 0 for kDisabled, 1 for "
"kCPU, 2 for kCUDA, 3 for kAll, default 0.");
DEFINE_int32(cinn_schedule_error_message_level,
Int32FromEnv("FLAGS_cinn_schedule_error_message_level", 0),
DEFINE_int32(cinn_error_message_level,
Int32FromEnv("FLAGS_cinn_error_message_level", 0),
"Specify the level of printing error message in the schedule."
"0 means short, 1 means detailed.");
......
// Copyright (c) 2021 CINN Authors. All Rights Reserved.
// Copyright (c) 2023 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
......@@ -14,4 +14,19 @@
#include "paddle/cinn/utils/error.h"
namespace cinn::utils {} // namespace cinn::utils
namespace cinn {
namespace utils {
std::string ErrorHandler::FormatErrorMessage(
const ErrorMessageLevel& err_msg_level) const {
std::ostringstream os;
std::string err_msg = err_msg_level == ErrorMessageLevel::kDetailed
? DetailedErrorMessage()
: GeneralErrorMessage();
os << "[Error info] " << err_msg;
return os.str();
}
} // namespace utils
} // namespace cinn
// Copyright (c) 2021 CINN Authors. All Rights Reserved.
// Copyright (c) 2023 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
......@@ -13,12 +13,159 @@
// limitations under the License.
#pragma once
//! This file includes some utilities imported from LLVM.
#include "llvm/Support/Error.h"
namespace cinn::utils {
#ifdef __GNUC__
#include <cxxabi.h> // for __cxa_demangle
#endif // __GNUC__
template <typename T>
using Expected = llvm::Expected<T>;
#if !defined(_WIN32)
#include <dlfcn.h> // dladdr
#include <unistd.h> // sleep, usleep
#else // _WIN32
#ifndef NOMINMAX
#define NOMINMAX // msvc max/min macro conflict with std::min/max
#endif
#include <windows.h> // GetModuleFileName, Sleep
#endif
} // namespace cinn::utils
#if !defined(_WIN32) && !defined(PADDLE_WITH_MUSL)
#include <execinfo.h>
#endif
#include <iostream>
#include <memory>
#include <sstream>
#include <stdexcept>
#include <string>
#include <vector>
namespace cinn {
namespace utils {
namespace enforce {
#ifdef __GNUC__
inline std::string demangle(std::string name) {
int status = -4; // some arbitrary value to eliminate the compiler warning
std::unique_ptr<char, void (*)(void*)> res{
abi::__cxa_demangle(name.c_str(), NULL, NULL, &status), std::free};
return (status == 0) ? res.get() : name;
}
#else
inline std::string demangle(std::string name) { return name; }
#endif
static std::string GetErrorSumaryString(const std::string& what,
const char* file,
int line) {
std::ostringstream sout;
sout << "\n----------------------\nError Message "
"Summary:\n----------------------\n";
sout << what << "(at " << file << " : " << line << ")" << std::endl;
return sout.str();
}
static std::string GetCurrentTraceBackString() {
std::ostringstream sout;
sout << "\n\n--------------------------------------\n";
sout << "C++ Traceback (most recent call last):";
sout << "\n--------------------------------------\n";
#if !defined(_WIN32) && !defined(PADDLE_WITH_MUSL)
static constexpr int TRACE_STACK_LIMIT = 100;
void* call_stack[TRACE_STACK_LIMIT];
auto size = backtrace(call_stack, TRACE_STACK_LIMIT);
auto symbols = backtrace_symbols(call_stack, size);
Dl_info info;
int idx = 0;
int end_idx = 0;
for (int i = size - 1; i >= end_idx; --i) {
if (dladdr(call_stack[i], &info) && info.dli_sname) {
auto demangled = demangle(info.dli_sname);
std::string path(info.dli_fname);
// C++ traceback info are from core.so
if (path.substr(path.length() - 3).compare(".so") == 0) {
sout << idx++ << " " << demangled << "\n";
}
}
}
free(symbols);
#else
sout << "Not support stack backtrace yet.\n";
#endif
return sout.str();
}
static std::string GetTraceBackString(const std::string& what,
const char* file,
int line) {
return GetCurrentTraceBackString() + GetErrorSumaryString(what, file, line);
}
struct EnforceNotMet : public std::exception {
public:
EnforceNotMet(const std::string& str, const char* file, int line)
: err_str_(GetTraceBackString(str, file, line)) {}
const char* what() const noexcept override { return err_str_.c_str(); }
private:
std::string err_str_;
};
#ifdef PADDLE_THROW
#define CINN_THROW PADDLE_THROW
#else
#define CINN_THROW(...) \
do { \
try { \
throw utils::enforce::EnforceNotMet(__VA_ARGS__, __FILE__, __LINE__); \
} catch (const std::exception& e) { \
std::cout << e.what() << std::endl; \
throw; \
} \
} while (0)
#endif
} // namespace enforce
/**
* \brief Indicates the level of printing error message in the current
* operation
*/
enum class ErrorMessageLevel : int32_t {
/** \brief Print an error message in short mode.
* Short mode shows which and where the error happens*/
kGeneral = 0,
/** \brief Print an error message in detailed mode.
* Detailed mode shows which and where the error happens, and the
* detailed input parameters.
*/
kDetailed = 1,
};
/**
* This handler is a base class dealing with the errors happen in in the current
* operation.
*/
class ErrorHandler {
public:
/**
* \brief Returns a short error message corresponding to the kGeneral error
* level.
*/
virtual std::string GeneralErrorMessage() const = 0;
/**
* \brief Returns a detailed error message corresponding to the kDetailed
* error level.
*/
virtual std::string DetailedErrorMessage() const = 0;
/**
* \brief Format the error message.
*/
std::string FormatErrorMessage(const ErrorMessageLevel& err_msg_level) const;
};
} // namespace utils
} // namespace cinn
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册