diff --git a/paddle/cinn/backends/ir_schedule_test.cc b/paddle/cinn/backends/ir_schedule_test.cc index d610df1990d10d73e2cfc63518aeacf6592e87d3..fa2b7b7299891dd3c55dd9ef1e676e51272f3368 100644 --- a/paddle/cinn/backends/ir_schedule_test.cc +++ b/paddle/cinn/backends/ir_schedule_test.cc @@ -177,7 +177,7 @@ void TestSplitThrow() { std::vector vec_ast{ast_expr}; ir::ModuleExpr mod_expr(vec_ast); ir::IRSchedule ir_sch( - mod_expr, -1, false, ir::ScheduleErrorMessageLevel::kGeneral); + mod_expr, -1, false, utils::ErrorMessageLevel::kGeneral); auto fused = ir_sch.Fuse("B", {0, 1}); // statement that cause the exception auto splited = ir_sch.Split(fused, {-1, -1}); @@ -196,7 +196,7 @@ void TestSplitThrow() { auto source_code = codegen.Compile(module, CodeGenC::OutputKind::CImpl); } TEST(IrSchedule, split_throw) { - ASSERT_THROW(TestSplitThrow(), ir::enforce::EnforceNotMet); + ASSERT_THROW(TestSplitThrow(), utils::enforce::EnforceNotMet); } TEST(IrSchedule, reorder1) { diff --git a/paddle/cinn/ir/schedule/ir_schedule.cc b/paddle/cinn/ir/schedule/ir_schedule.cc index 88609c7a7eb9b2159ee3d258b1df95881a97f9e1..0c3632ffd22eb3c493e0243b5b5f3df4636e112c 100644 --- a/paddle/cinn/ir/schedule/ir_schedule.cc +++ b/paddle/cinn/ir/schedule/ir_schedule.cc @@ -41,7 +41,7 @@ #include "paddle/cinn/optim/replace_var_with_expr.h" #include "paddle/cinn/utils/string.h" -DECLARE_int32(cinn_schedule_error_message_level); +DECLARE_int32(cinn_error_message_level); namespace cinn { namespace ir { @@ -54,12 +54,11 @@ class ScheduleImpl { ScheduleImpl() = default; explicit ScheduleImpl(const ModuleExpr& module_expr, bool debug_flag = false, - ScheduleErrorMessageLevel err_msg_level = - ScheduleErrorMessageLevel::kGeneral) + utils::ErrorMessageLevel err_msg_level = + utils::ErrorMessageLevel::kGeneral) : module_expr_(module_expr), debug_flag_(debug_flag) { - err_msg_level_ = static_cast( - FLAGS_cinn_schedule_error_message_level || - static_cast(err_msg_level)); + err_msg_level_ = static_cast( + FLAGS_cinn_error_message_level || static_cast(err_msg_level)); } explicit ScheduleImpl(ModuleExpr&& module_expr) : module_expr_(std::move(module_expr)) {} @@ -138,8 +137,7 @@ class ScheduleImpl { ModuleExpr module_expr_; bool debug_flag_{false}; - ScheduleErrorMessageLevel err_msg_level_ = - ScheduleErrorMessageLevel::kGeneral; + utils::ErrorMessageLevel err_msg_level_ = utils::ErrorMessageLevel::kGeneral; }; /** \brief A macro that guards the beginning of each implementation of schedule @@ -152,10 +150,10 @@ class ScheduleImpl { * @param err_msg_level A ScheduleErrorMessageLevel enum, level of error message * printing */ -#define CINN_IR_SCHEDULE_END(primitive, err_msg_level) \ - } \ - catch (const IRScheduleErrorHandler& err_hanlder) { \ - CINN_THROW(err_hanlder.FormatErrorMessage(primitive, err_msg_level)); \ +#define CINN_IR_SCHEDULE_END(err_msg_level) \ + } \ + catch (const utils::ErrorHandler& err_hanlder) { \ + CINN_THROW(err_hanlder.FormatErrorMessage(err_msg_level)); \ } std::vector ScheduleImpl::Split(const Expr& loop, @@ -177,7 +175,7 @@ std::vector ScheduleImpl::Split(const Expr& loop, std::vector processed_factors; CINN_IR_SCHEDULE_BEGIN(); processed_factors = ValidateFactors(factors, tot_extent, this->module_expr_); - CINN_IR_SCHEDULE_END("split", this->err_msg_level_); + CINN_IR_SCHEDULE_END(this->err_msg_level_); int prod_size = std::accumulate(processed_factors.begin(), processed_factors.end(), 1, @@ -2316,7 +2314,7 @@ IRSchedule::IRSchedule() {} IRSchedule::IRSchedule(const ModuleExpr& module_expr, utils::LinearRandomEngine::StateType rand_seed, bool debug_flag, - ScheduleErrorMessageLevel err_msg_level) { + utils::ErrorMessageLevel err_msg_level) { impl_ = std::make_unique(module_expr, debug_flag, err_msg_level); this->InitSeed(rand_seed); diff --git a/paddle/cinn/ir/schedule/ir_schedule.h b/paddle/cinn/ir/schedule/ir_schedule.h index c36a3363c6dcc01d0a46b95abcff0f2f4e8aa542..4d36368603a8bbb7a6205114dbc669f7820c9008 100644 --- a/paddle/cinn/ir/schedule/ir_schedule.h +++ b/paddle/cinn/ir/schedule/ir_schedule.h @@ -24,25 +24,12 @@ #include "paddle/cinn/ir/schedule/schedule_desc.h" #include "paddle/cinn/ir/tensor.h" #include "paddle/cinn/ir/utils/ir_mutator.h" +#include "paddle/cinn/utils/error.h" #include "paddle/cinn/utils/random_engine.h" namespace cinn { namespace ir { -/** - * \brief Indicates the level of printing error message in the current Schedule - */ -enum class ScheduleErrorMessageLevel : int32_t { - /** \brief Print an error message in short mode. - * Short mode shows which and where the schedule error happens*/ - kGeneral = 0, - /** \brief Print an error message in detailed mode. - * Detailed mode shows which and where the schedule error happens, and the - * schedule input parameters. - */ - kDetailed = 1, -}; - /** * A struct representing a module that contains Expr. This struct is only used * in Schedule process. @@ -85,8 +72,8 @@ class IRSchedule { explicit IRSchedule(const ModuleExpr& modexpr, utils::LinearRandomEngine::StateType rand_seed = -1, bool debug_flag = false, - ScheduleErrorMessageLevel err_msg_level = - ScheduleErrorMessageLevel::kGeneral); + utils::ErrorMessageLevel err_msg_level = + utils::ErrorMessageLevel::kGeneral); IRSchedule(ir::ModuleExpr&& mod_expr, ScheduleDesc&& trace, utils::LinearRandomEngine::StateType rand_seed = -1); diff --git a/paddle/cinn/ir/schedule/ir_schedule_error.cc b/paddle/cinn/ir/schedule/ir_schedule_error.cc index 30c970ffd16d6abbeac2c0a461130bb966d9fcbc..fdccbaa36ef84b65f02437bcbced59a1f9ed9f9f 100644 --- a/paddle/cinn/ir/schedule/ir_schedule_error.cc +++ b/paddle/cinn/ir/schedule/ir_schedule_error.cc @@ -20,7 +20,11 @@ namespace cinn { namespace ir { std::string IRScheduleErrorHandler::GeneralErrorMessage() const { - return this->err_msg_; + std::ostringstream os; + os << "[IRScheduleError] An error occurred in the scheduel primitive < " + << this->primitive_ << " >. " << std::endl; + os << this->err_msg_; + return os.str(); } std::string IRScheduleErrorHandler::DetailedErrorMessage() const { @@ -31,44 +35,5 @@ std::string IRScheduleErrorHandler::DetailedErrorMessage() const { return os.str(); } -std::string IRScheduleErrorHandler::FormatErrorMessage( - const std::string& primitive, - const ScheduleErrorMessageLevel& err_msg_level) const { - std::ostringstream os; - std::string err_msg = err_msg_level == ScheduleErrorMessageLevel::kDetailed - ? DetailedErrorMessage() - : GeneralErrorMessage(); - - os << "[IRScheduleError] An error occurred in the scheduel primitive <" - << primitive << ">. " << std::endl; - os << "[Error info] " << err_msg; - return os.str(); -} - -std::string NegativeFactorErrorMessage(const int64_t& factor, - const size_t& idx) { - std::ostringstream os; - os << "The params in factors of Split should be positive. However, the " - "factor at position " - << idx << " is " << factor << std::endl; - return os.str(); -} - -std::string InferFactorErrorMessage() { - std::ostringstream os; - os << "The params in factors of Split should not be less than -1 or have " - "more than one -1!" - << std::endl; - return os.str(); -} - -std::string FactorProductErrorMessage() { - std::ostringstream os; - os << "In Split, the factors' product should be not larger than or equal " - "to original loop's extent!" - << std::endl; - return os.str(); -} - } // namespace ir } // namespace cinn diff --git a/paddle/cinn/ir/schedule/ir_schedule_error.h b/paddle/cinn/ir/schedule/ir_schedule_error.h index 9c4f15ddab962435b5910dde4529620f084480a1..1326bfd8852b0619c6ccd5da392bb9bb8943d99e 100644 --- a/paddle/cinn/ir/schedule/ir_schedule_error.h +++ b/paddle/cinn/ir/schedule/ir_schedule_error.h @@ -14,130 +14,25 @@ #pragma once -#ifdef __GNUC__ -#include // for __cxa_demangle -#endif // __GNUC__ - -#if !defined(_WIN32) -#include // dladdr -#include // sleep, usleep -#else // _WIN32 -#ifndef NOMINMAX -#define NOMINMAX // msvc max/min macro conflict with std::min/max -#endif -#include // GetModuleFileName, Sleep -#endif - -#if !defined(_WIN32) && !defined(PADDLE_WITH_MUSL) -#include -#endif - -#include -#include -#include -#include -#include -#include #include "paddle/cinn/ir/schedule/ir_schedule.h" namespace cinn { namespace ir { -namespace enforce { - -#ifdef __GNUC__ -inline std::string demangle(std::string name) { - int status = -4; // some arbitrary value to eliminate the compiler warning - std::unique_ptr res{ - abi::__cxa_demangle(name.c_str(), NULL, NULL, &status), std::free}; - return (status == 0) ? res.get() : name; -} -#else -inline std::string demangle(std::string name) { return name; } -#endif - -static std::string GetErrorSumaryString(const std::string& what, - const char* file, - int line) { - std::ostringstream sout; - sout << "\n----------------------\nError Message " - "Summary:\n----------------------\n"; - sout << what << "(at " << file << " : " << line << ")" << std::endl; - return sout.str(); -} - -static std::string GetCurrentTraceBackString() { - std::ostringstream sout; - sout << "\n\n--------------------------------------\n"; - sout << "C++ Traceback (most recent call last):"; - sout << "\n--------------------------------------\n"; -#if !defined(_WIN32) && !defined(PADDLE_WITH_MUSL) - static constexpr int TRACE_STACK_LIMIT = 100; - - void* call_stack[TRACE_STACK_LIMIT]; - auto size = backtrace(call_stack, TRACE_STACK_LIMIT); - auto symbols = backtrace_symbols(call_stack, size); - Dl_info info; - int idx = 0; - int end_idx = 0; - for (int i = size - 1; i >= end_idx; --i) { - if (dladdr(call_stack[i], &info) && info.dli_sname) { - auto demangled = demangle(info.dli_sname); - std::string path(info.dli_fname); - // C++ traceback info are from core.so - if (path.substr(path.length() - 3).compare(".so") == 0) { - sout << idx++ << " " << demangled << "\n"; - } - } - } - free(symbols); -#else - sout << "Not support stack backtrace yet.\n"; -#endif - return sout.str(); -} - -static std::string GetTraceBackString(const std::string& what, - const char* file, - int line) { - return GetCurrentTraceBackString() + GetErrorSumaryString(what, file, line); -} - -struct EnforceNotMet : public std::exception { - public: - EnforceNotMet(const std::string& str, const char* file, int line) - : err_str_(GetTraceBackString(str, file, line)) {} - - const char* what() const noexcept override { return err_str_.c_str(); } - - private: - std::string err_str_; -}; - -#define CINN_THROW(...) \ - do { \ - try { \ - throw enforce::EnforceNotMet(__VA_ARGS__, __FILE__, __LINE__); \ - } catch (const std::exception& e) { \ - std::cout << e.what() << std::endl; \ - throw; \ - } \ - } while (0) -} // namespace enforce - /** * This handler is dealing with the errors happen in in the current * Scheduling. */ -class IRScheduleErrorHandler { +class IRScheduleErrorHandler : public utils::ErrorHandler { public: /** * \brief constructor * \param err_msg the error message */ - explicit IRScheduleErrorHandler(const std::string& err_msg, + explicit IRScheduleErrorHandler(const std::string& primitive, + const std::string& err_msg, const ModuleExpr& module_expr) - : err_msg_(err_msg), module_expr_(module_expr) {} + : primitive_(primitive), err_msg_(err_msg), module_expr_(module_expr) {} /** * \brief Returns a short error message corresponding to the kGeneral error @@ -151,25 +46,11 @@ class IRScheduleErrorHandler { */ std::string DetailedErrorMessage() const; - /** - * \brief Returns a detailed error message corresponding to the kDetailed - * error level. - */ - std::string FormatErrorMessage( - const std::string& primitive, - const ScheduleErrorMessageLevel& err_msg_level) const; - private: - ModuleExpr module_expr_; + std::string primitive_; std::string err_msg_; + ModuleExpr module_expr_; }; -std::string NegativeFactorErrorMessage(const int64_t& factor, - const size_t& idx); - -std::string InferFactorErrorMessage(); - -std::string FactorProductErrorMessage(); - } // namespace ir } // namespace cinn diff --git a/paddle/cinn/ir/schedule/ir_schedule_util.cc b/paddle/cinn/ir/schedule/ir_schedule_util.cc index adfe5fdcef86147890f5c0da54d919241a16ff28..87b7147d97803820b9a25b903a6a2432961fe27f 100644 --- a/paddle/cinn/ir/schedule/ir_schedule_util.cc +++ b/paddle/cinn/ir/schedule/ir_schedule_util.cc @@ -222,6 +222,7 @@ void ReplaceExpr(Expr* source, std::vector ValidateFactors(const std::vector& factors, int total_extent, const ModuleExpr& module_expr) { + const std::string primitive = "split"; CHECK(!factors.empty()) << "The factors param of Split should not be empty! Please check."; bool has_minus_one = false; @@ -230,11 +231,19 @@ std::vector ValidateFactors(const std::vector& factors, for (auto& i : factors) { idx++; if (i == 0 || i < -1) { - throw IRScheduleErrorHandler(NegativeFactorErrorMessage(i, idx), - module_expr); + std::ostringstream os; + os << "The params in factors of Split should be positive. However, the " + "factor at position " + << idx << " is " << i << std::endl; + throw IRScheduleErrorHandler(primitive, os.str(), module_expr); } else if (i == -1) { if (has_minus_one) { - throw IRScheduleErrorHandler(InferFactorErrorMessage(), module_expr); + std::ostringstream os; + os << "The params in factors of Split should not be less than -1 or " + "have " + "more than one -1!" + << std::endl; + throw IRScheduleErrorHandler(primitive, os.str(), module_expr); } has_minus_one = true; } else { @@ -244,12 +253,20 @@ std::vector ValidateFactors(const std::vector& factors, std::vector validated_factors = factors; if (!has_minus_one) { if (product < total_extent) { - throw IRScheduleErrorHandler(FactorProductErrorMessage(), module_expr); + std::ostringstream os; + os << "In Split, the factors' product should be not larger than or equal " + "to original loop's extent!" + << std::endl; + throw IRScheduleErrorHandler(primitive, os.str(), module_expr); } return validated_factors; } else { if (product > total_extent) { - throw IRScheduleErrorHandler(FactorProductErrorMessage(), module_expr); + std::ostringstream os; + os << "In Split, the factors' product should be not larger than or equal " + "to original loop's extent!" + << std::endl; + throw IRScheduleErrorHandler(primitive, os.str(), module_expr); } int minus_one_candidate = static_cast( ceil(static_cast(total_extent) / static_cast(product))); diff --git a/paddle/cinn/runtime/flags.cc b/paddle/cinn/runtime/flags.cc index 05b181a315a1861357a188efc72ba277f847ba88..524409d50c370bcd550a7f367a4e58f2858e8249 100644 --- a/paddle/cinn/runtime/flags.cc +++ b/paddle/cinn/runtime/flags.cc @@ -164,8 +164,8 @@ DEFINE_int32(cinn_profiler_state, "Specify the ProfilerState by Int in CINN, 0 for kDisabled, 1 for " "kCPU, 2 for kCUDA, 3 for kAll, default 0."); -DEFINE_int32(cinn_schedule_error_message_level, - Int32FromEnv("FLAGS_cinn_schedule_error_message_level", 0), +DEFINE_int32(cinn_error_message_level, + Int32FromEnv("FLAGS_cinn_error_message_level", 0), "Specify the level of printing error message in the schedule." "0 means short, 1 means detailed."); diff --git a/paddle/cinn/utils/error.cc b/paddle/cinn/utils/error.cc index e3920d45b0fcdeec88b9ca9eb915286d85e1cd7b..123c65d545fce73b25fdd7f863e2342e5ee4663a 100644 --- a/paddle/cinn/utils/error.cc +++ b/paddle/cinn/utils/error.cc @@ -1,4 +1,4 @@ -// Copyright (c) 2021 CINN Authors. All Rights Reserved. +// Copyright (c) 2023 CINN Authors. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -14,4 +14,19 @@ #include "paddle/cinn/utils/error.h" -namespace cinn::utils {} // namespace cinn::utils +namespace cinn { +namespace utils { + +std::string ErrorHandler::FormatErrorMessage( + const ErrorMessageLevel& err_msg_level) const { + std::ostringstream os; + std::string err_msg = err_msg_level == ErrorMessageLevel::kDetailed + ? DetailedErrorMessage() + : GeneralErrorMessage(); + + os << "[Error info] " << err_msg; + return os.str(); +} + +} // namespace utils +} // namespace cinn diff --git a/paddle/cinn/utils/error.h b/paddle/cinn/utils/error.h index f7c277e5a0698d5d5f5845ebe13c8c3724571000..118018d236773c19d1fac2ab591c0ba643edd0fe 100644 --- a/paddle/cinn/utils/error.h +++ b/paddle/cinn/utils/error.h @@ -1,4 +1,4 @@ -// Copyright (c) 2021 CINN Authors. All Rights Reserved. +// Copyright (c) 2023 CINN Authors. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -13,12 +13,159 @@ // limitations under the License. #pragma once -//! This file includes some utilities imported from LLVM. -#include "llvm/Support/Error.h" -namespace cinn::utils { +#ifdef __GNUC__ +#include // for __cxa_demangle +#endif // __GNUC__ -template -using Expected = llvm::Expected; +#if !defined(_WIN32) +#include // dladdr +#include // sleep, usleep +#else // _WIN32 +#ifndef NOMINMAX +#define NOMINMAX // msvc max/min macro conflict with std::min/max +#endif +#include // GetModuleFileName, Sleep +#endif -} // namespace cinn::utils +#if !defined(_WIN32) && !defined(PADDLE_WITH_MUSL) +#include +#endif + +#include +#include +#include +#include +#include +#include + +namespace cinn { +namespace utils { + +namespace enforce { + +#ifdef __GNUC__ +inline std::string demangle(std::string name) { + int status = -4; // some arbitrary value to eliminate the compiler warning + std::unique_ptr res{ + abi::__cxa_demangle(name.c_str(), NULL, NULL, &status), std::free}; + return (status == 0) ? res.get() : name; +} +#else +inline std::string demangle(std::string name) { return name; } +#endif + +static std::string GetErrorSumaryString(const std::string& what, + const char* file, + int line) { + std::ostringstream sout; + sout << "\n----------------------\nError Message " + "Summary:\n----------------------\n"; + sout << what << "(at " << file << " : " << line << ")" << std::endl; + return sout.str(); +} + +static std::string GetCurrentTraceBackString() { + std::ostringstream sout; + sout << "\n\n--------------------------------------\n"; + sout << "C++ Traceback (most recent call last):"; + sout << "\n--------------------------------------\n"; +#if !defined(_WIN32) && !defined(PADDLE_WITH_MUSL) + static constexpr int TRACE_STACK_LIMIT = 100; + + void* call_stack[TRACE_STACK_LIMIT]; + auto size = backtrace(call_stack, TRACE_STACK_LIMIT); + auto symbols = backtrace_symbols(call_stack, size); + Dl_info info; + int idx = 0; + int end_idx = 0; + for (int i = size - 1; i >= end_idx; --i) { + if (dladdr(call_stack[i], &info) && info.dli_sname) { + auto demangled = demangle(info.dli_sname); + std::string path(info.dli_fname); + // C++ traceback info are from core.so + if (path.substr(path.length() - 3).compare(".so") == 0) { + sout << idx++ << " " << demangled << "\n"; + } + } + } + free(symbols); +#else + sout << "Not support stack backtrace yet.\n"; +#endif + return sout.str(); +} + +static std::string GetTraceBackString(const std::string& what, + const char* file, + int line) { + return GetCurrentTraceBackString() + GetErrorSumaryString(what, file, line); +} + +struct EnforceNotMet : public std::exception { + public: + EnforceNotMet(const std::string& str, const char* file, int line) + : err_str_(GetTraceBackString(str, file, line)) {} + + const char* what() const noexcept override { return err_str_.c_str(); } + + private: + std::string err_str_; +}; + +#ifdef PADDLE_THROW +#define CINN_THROW PADDLE_THROW +#else +#define CINN_THROW(...) \ + do { \ + try { \ + throw utils::enforce::EnforceNotMet(__VA_ARGS__, __FILE__, __LINE__); \ + } catch (const std::exception& e) { \ + std::cout << e.what() << std::endl; \ + throw; \ + } \ + } while (0) +#endif +} // namespace enforce + +/** + * \brief Indicates the level of printing error message in the current + * operation + */ +enum class ErrorMessageLevel : int32_t { + /** \brief Print an error message in short mode. + * Short mode shows which and where the error happens*/ + kGeneral = 0, + /** \brief Print an error message in detailed mode. + * Detailed mode shows which and where the error happens, and the + * detailed input parameters. + */ + kDetailed = 1, +}; + +/** + * This handler is a base class dealing with the errors happen in in the current + * operation. + */ +class ErrorHandler { + public: + /** + * \brief Returns a short error message corresponding to the kGeneral error + * level. + */ + virtual std::string GeneralErrorMessage() const = 0; + + /** + * \brief Returns a detailed error message corresponding to the kDetailed + * error level. + */ + virtual std::string DetailedErrorMessage() const = 0; + + /** + * \brief Format the error message. + */ + std::string FormatErrorMessage(const ErrorMessageLevel& err_msg_level) const; +}; + +} // namespace utils +} // namespace cinn