未验证 提交 5f05b22b 编写于 作者: Z Zhang Zheng 提交者: GitHub

Cinn schedule error (#54983)

* [CINN] Schedule error message optimization

* format code style

* add test

* fix format

* using CINN_THROW and using flags

* optimize error msg

* do not use abtract class of error hanlder

* fix header
上级 0fd6efbb
......@@ -25,6 +25,7 @@
#include "paddle/cinn/backends/codegen_cuda_dev.h"
#include "paddle/cinn/cinn.h"
#include "paddle/cinn/ir/ir_printer.h"
#include "paddle/cinn/ir/ir_schedule_error.h"
#include "paddle/cinn/lang/lower.h"
#include "paddle/cinn/optim/ir_simplify.h"
#include "paddle/cinn/optim/remove_schedule_block.h"
......@@ -156,6 +157,48 @@ void test_split_and_fuse2(void* _args, int32_t num_args)
ASSERT_EQ(utils::Trim(target_code), utils::Trim(source_code));
}
void TestSplitThrow() {
Context::Global().ResetNameId();
Expr M(32);
Expr N(32);
Expr P(32);
Target target = common::DefaultHostTarget();
Placeholder<float> A("A", {M, N});
auto B = Compute(
{M, N}, [&](Var i, Var j) { return A(i, j); }, "B");
auto stages = CreateStages({A, B});
auto func = cinn::lang::LowerVec(
"test_split_throw", stages, {A, B}, {}, {}, nullptr, target, true);
auto ast_expr = func[0]->body;
std::vector<Expr> vec_ast{ast_expr};
ir::ModuleExpr mod_expr(vec_ast);
ir::IRSchedule ir_sch(
mod_expr, -1, false, ir::ScheduleErrorMessageLevel::kGeneral);
auto fused = ir_sch.Fuse("B", {0, 1});
// statement that cause the exception
auto splited = ir_sch.Split(fused, {-1, -1});
auto loops = ir_sch.GetLoops("B");
fused = ir_sch.Fuse(loops);
splited = ir_sch.Split(fused, {256, -1});
Module::Builder builder("module1", target);
for (auto& i : func) {
builder.AddFunction(i);
}
auto module = builder.Build();
CodeGenC codegen(target);
codegen.SetInlineBuiltinCodes(false);
auto source_code = codegen.Compile(module, CodeGenC::OutputKind::CImpl);
}
TEST(IrSchedule, split_throw) {
ASSERT_THROW(TestSplitThrow(), ir::enforce::EnforceNotMet);
}
TEST(IrSchedule, reorder1) {
Context::Global().ResetNameId();
Expr M(32);
......
......@@ -8,6 +8,7 @@ gather_srcs(
ir.cc
ir_base.cc
ir_schedule.cc
ir_schedule_error.cc
ir_schedule_util.cc
ir_visitor.cc
ir_printer.cc
......
......@@ -33,6 +33,7 @@
#include "paddle/cinn/ir/ir_mutator.h"
#include "paddle/cinn/ir/ir_operators.h"
#include "paddle/cinn/ir/ir_printer.h"
#include "paddle/cinn/ir/ir_schedule_error.h"
#include "paddle/cinn/ir/ir_schedule_util.h"
#include "paddle/cinn/ir/ir_visitor.h"
#include "paddle/cinn/lang/compute.h"
......@@ -41,6 +42,8 @@
#include "paddle/cinn/optim/replace_var_with_expr.h"
#include "paddle/cinn/utils/string.h"
DECLARE_int32(cinn_schedule_error_message_level);
namespace cinn {
namespace ir {
......@@ -50,8 +53,15 @@ namespace ir {
class ScheduleImpl {
public:
ScheduleImpl() = default;
explicit ScheduleImpl(const ModuleExpr& module_expr, bool debug_flag = false)
: module_expr_(module_expr), debug_flag_(debug_flag) {}
explicit ScheduleImpl(const ModuleExpr& module_expr,
bool debug_flag = false,
ScheduleErrorMessageLevel err_msg_level =
ScheduleErrorMessageLevel::kGeneral)
: module_expr_(module_expr), debug_flag_(debug_flag) {
err_msg_level_ = static_cast<ScheduleErrorMessageLevel>(
FLAGS_cinn_schedule_error_message_level ||
static_cast<int>(err_msg_level));
}
explicit ScheduleImpl(ModuleExpr&& module_expr)
: module_expr_(std::move(module_expr)) {}
......@@ -129,8 +139,26 @@ class ScheduleImpl {
ModuleExpr module_expr_;
bool debug_flag_{false};
ScheduleErrorMessageLevel err_msg_level_ =
ScheduleErrorMessageLevel::kGeneral;
};
/** \brief A macro that guards the beginning of each implementation of schedule
*/
#define CINN_IR_SCHEDULE_BEGIN() try {
/**
* \brief A macro that pairs with `CINN_IR_SCHEDULE_BEGIN`, handling potential
* errors and error message printing.
* @param primitive A string representing the kind of schedule primitive.
* @param err_msg_level A ScheduleErrorMessageLevel enum, level of error message
* printing
*/
#define CINN_IR_SCHEDULE_END(primitive, err_msg_level) \
} \
catch (const IRScheduleErrorHandler& err_hanlder) { \
CINN_THROW(err_hanlder.FormatErrorMessage(primitive, err_msg_level)); \
}
std::vector<Expr> ScheduleImpl::Split(const Expr& loop,
const std::vector<int>& factors) {
CHECK(loop.As<ir::For>())
......@@ -147,7 +175,10 @@ std::vector<Expr> ScheduleImpl::Split(const Expr& loop,
<< ") at loop:\n"
<< loop;
auto processed_factors = ValidateFactors(factors, tot_extent);
std::vector<int> processed_factors;
CINN_IR_SCHEDULE_BEGIN();
processed_factors = ValidateFactors(factors, tot_extent, this->module_expr_);
CINN_IR_SCHEDULE_END("split", this->err_msg_level_);
int prod_size = std::accumulate(processed_factors.begin(),
processed_factors.end(),
1,
......@@ -1194,7 +1225,6 @@ struct LoopReconstructor : public ir::IRMutator<> {
return utils::Join(new_var_names, ",");
}
private:
public:
/*! \brief The root block */
Expr root_;
......@@ -2286,8 +2316,10 @@ IRSchedule::IRSchedule() {}
IRSchedule::IRSchedule(const ModuleExpr& module_expr,
utils::LinearRandomEngine::StateType rand_seed,
bool debug_flag) {
impl_ = std::make_unique<ScheduleImpl>(module_expr, debug_flag);
bool debug_flag,
ScheduleErrorMessageLevel err_msg_level) {
impl_ =
std::make_unique<ScheduleImpl>(module_expr, debug_flag, err_msg_level);
this->InitSeed(rand_seed);
}
......
......@@ -29,6 +29,20 @@
namespace cinn {
namespace ir {
/**
* \brief Indicates the level of printing error message in the current Schedule
*/
enum class ScheduleErrorMessageLevel : int32_t {
/** \brief Print an error message in short mode.
* Short mode shows which and where the schedule error happens*/
kGeneral = 0,
/** \brief Print an error message in detailed mode.
* Detailed mode shows which and where the schedule error happens, and the
* schedule input parameters.
*/
kDetailed = 1,
};
/**
* A struct representing a module that contains Expr. This struct is only used
* in Schedule process.
......@@ -70,7 +84,9 @@ class IRSchedule {
IRSchedule();
explicit IRSchedule(const ModuleExpr& modexpr,
utils::LinearRandomEngine::StateType rand_seed = -1,
bool debug_flag = false);
bool debug_flag = false,
ScheduleErrorMessageLevel err_msg_level =
ScheduleErrorMessageLevel::kGeneral);
IRSchedule(ir::ModuleExpr&& mod_expr,
ScheduleDesc&& trace,
utils::LinearRandomEngine::StateType rand_seed = -1);
......
// Copyright (c) 2023 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/cinn/ir/ir_schedule_error.h"
#include "paddle/cinn/ir/ir.h"
#include "paddle/cinn/ir/ir_printer.h"
namespace cinn {
namespace ir {
std::string IRScheduleErrorHandler::GeneralErrorMessage() const {
return this->err_msg_;
}
std::string IRScheduleErrorHandler::DetailedErrorMessage() const {
std::ostringstream os;
os << GeneralErrorMessage();
os << "[Expr info] The Expr of current schedule is: "
<< this->module_expr_.GetExprs() << std::endl;
return os.str();
}
std::string IRScheduleErrorHandler::FormatErrorMessage(
const std::string& primitive,
const ScheduleErrorMessageLevel& err_msg_level) const {
std::ostringstream os;
std::string err_msg = err_msg_level == ScheduleErrorMessageLevel::kDetailed
? DetailedErrorMessage()
: GeneralErrorMessage();
os << "[IRScheduleError] An error occurred in the scheduel primitive <"
<< primitive << ">. " << std::endl;
os << "[Error info] " << err_msg;
return os.str();
}
std::string NegativeFactorErrorMessage(const int64_t& factor,
const size_t& idx) {
std::ostringstream os;
os << "The params in factors of Split should be positive. However, the "
"factor at position "
<< idx << " is " << factor << std::endl;
return os.str();
}
std::string InferFactorErrorMessage() {
std::ostringstream os;
os << "The params in factors of Split should not be less than -1 or have "
"more than one -1!"
<< std::endl;
return os.str();
}
std::string FactorProductErrorMessage() {
std::ostringstream os;
os << "In Split, the factors' product should be not larger than or equal "
"to original loop's extent!"
<< std::endl;
return os.str();
}
} // namespace ir
} // namespace cinn
// Copyright (c) 2023 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#ifdef __GNUC__
#include <cxxabi.h> // for __cxa_demangle
#endif // __GNUC__
#if !defined(_WIN32)
#include <dlfcn.h> // dladdr
#include <unistd.h> // sleep, usleep
#else // _WIN32
#ifndef NOMINMAX
#define NOMINMAX // msvc max/min macro conflict with std::min/max
#endif
#include <windows.h> // GetModuleFileName, Sleep
#endif
#if !defined(_WIN32) && !defined(PADDLE_WITH_MUSL)
#include <execinfo.h>
#endif
#include <iostream>
#include <memory>
#include <sstream>
#include <stdexcept>
#include <string>
#include <vector>
#include "paddle/cinn/ir/ir_schedule.h"
namespace cinn {
namespace ir {
namespace enforce {
#ifdef __GNUC__
inline std::string demangle(std::string name) {
int status = -4; // some arbitrary value to eliminate the compiler warning
std::unique_ptr<char, void (*)(void*)> res{
abi::__cxa_demangle(name.c_str(), NULL, NULL, &status), std::free};
return (status == 0) ? res.get() : name;
}
#else
inline std::string demangle(std::string name) { return name; }
#endif
static std::string GetErrorSumaryString(const std::string& what,
const char* file,
int line) {
std::ostringstream sout;
sout << "\n----------------------\nError Message "
"Summary:\n----------------------\n";
sout << what << "(at " << file << " : " << line << ")" << std::endl;
return sout.str();
}
static std::string GetCurrentTraceBackString() {
std::ostringstream sout;
sout << "\n\n--------------------------------------\n";
sout << "C++ Traceback (most recent call last):";
sout << "\n--------------------------------------\n";
#if !defined(_WIN32) && !defined(PADDLE_WITH_MUSL)
static constexpr int TRACE_STACK_LIMIT = 100;
void* call_stack[TRACE_STACK_LIMIT];
auto size = backtrace(call_stack, TRACE_STACK_LIMIT);
auto symbols = backtrace_symbols(call_stack, size);
Dl_info info;
int idx = 0;
int end_idx = 0;
for (int i = size - 1; i >= end_idx; --i) {
if (dladdr(call_stack[i], &info) && info.dli_sname) {
auto demangled = demangle(info.dli_sname);
std::string path(info.dli_fname);
// C++ traceback info are from core.so
if (path.substr(path.length() - 3).compare(".so") == 0) {
sout << idx++ << " " << demangled << "\n";
}
}
}
free(symbols);
#else
sout << "Not support stack backtrace yet.\n";
#endif
return sout.str();
}
static std::string GetTraceBackString(const std::string& what,
const char* file,
int line) {
return GetCurrentTraceBackString() + GetErrorSumaryString(what, file, line);
}
struct EnforceNotMet : public std::exception {
public:
EnforceNotMet(const std::string& str, const char* file, int line)
: err_str_(GetTraceBackString(str, file, line)) {}
const char* what() const noexcept override { return err_str_.c_str(); }
private:
std::string err_str_;
};
#define CINN_THROW(...) \
do { \
try { \
throw enforce::EnforceNotMet(__VA_ARGS__, __FILE__, __LINE__); \
} catch (const std::exception& e) { \
std::cout << e.what() << std::endl; \
throw; \
} \
} while (0)
} // namespace enforce
/**
* This handler is dealing with the errors happen in in the current
* Scheduling.
*/
class IRScheduleErrorHandler {
public:
/**
* \brief constructor
* \param err_msg the error message
*/
explicit IRScheduleErrorHandler(const std::string& err_msg,
const ModuleExpr& module_expr)
: err_msg_(err_msg), module_expr_(module_expr) {}
/**
* \brief Returns a short error message corresponding to the kGeneral error
* level.
*/
std::string GeneralErrorMessage() const;
/**
* \brief Returns a detailed error message corresponding to the kDetailed
* error level.
*/
std::string DetailedErrorMessage() const;
/**
* \brief Returns a detailed error message corresponding to the kDetailed
* error level.
*/
std::string FormatErrorMessage(
const std::string& primitive,
const ScheduleErrorMessageLevel& err_msg_level) const;
private:
ModuleExpr module_expr_;
std::string err_msg_;
};
std::string NegativeFactorErrorMessage(const int64_t& factor,
const size_t& idx);
std::string InferFactorErrorMessage();
std::string FactorProductErrorMessage();
} // namespace ir
} // namespace cinn
......@@ -220,19 +220,22 @@ void ReplaceExpr(Expr* source,
}
std::vector<int> ValidateFactors(const std::vector<int>& factors,
int total_extent) {
int total_extent,
const ModuleExpr& module_expr) {
CHECK(!factors.empty())
<< "The factors param of Split should not be empty! Please check.";
bool has_minus_one = false;
int product = 1;
int idx = -1;
for (auto& i : factors) {
CHECK(i != 0)
<< "The params in factors of Split should not be 0! Please check.";
CHECK(i >= -1) << "The params in factors of Split should not be less than "
"-1! Please check.";
if (i == -1) {
CHECK(!has_minus_one) << "The params in factors of Split should not have "
"more than one -1! Please check.";
idx++;
if (i == 0 || i < -1) {
throw IRScheduleErrorHandler(NegativeFactorErrorMessage(i, idx),
module_expr);
} else if (i == -1) {
if (has_minus_one) {
throw IRScheduleErrorHandler(InferFactorErrorMessage(), module_expr);
}
has_minus_one = true;
} else {
product *= i;
......@@ -240,15 +243,14 @@ std::vector<int> ValidateFactors(const std::vector<int>& factors,
}
std::vector<int> validated_factors = factors;
if (!has_minus_one) {
CHECK_GE(product, total_extent)
<< "In Split, the factors' product should be equal to original loop's "
"extent! Please check.";
if (product < total_extent) {
throw IRScheduleErrorHandler(FactorProductErrorMessage(), module_expr);
}
return validated_factors;
} else {
CHECK_LE(product, total_extent)
<< "In Split, when there is -1 in factors, the other factors' product "
"should be <= "
"original loop's extent! Please check.";
if (product > total_extent) {
throw IRScheduleErrorHandler(FactorProductErrorMessage(), module_expr);
}
int minus_one_candidate = static_cast<int>(
ceil(static_cast<double>(total_extent) / static_cast<double>(product)));
for (int i = 0; i < validated_factors.size(); ++i) {
......
......@@ -23,6 +23,7 @@
#include "paddle/cinn/ir/ir.h"
#include "paddle/cinn/ir/ir_base.h"
#include "paddle/cinn/ir/ir_mutator.h"
#include "paddle/cinn/ir/ir_schedule_error.h"
#include "paddle/cinn/ir/tensor.h"
#include "paddle/cinn/utils/random_engine.h"
#include "paddle/cinn/utils/string.h"
......@@ -248,7 +249,8 @@ void ReplaceExpr(Expr* source,
* @return return The valiated factors.
*/
std::vector<int> ValidateFactors(const std::vector<int>& factors,
int total_extent);
int total_extent,
const ModuleExpr& module_expr);
void CHECKRfactorValidation(const Expr& rf_loop, int rf_axis);
......
......@@ -164,6 +164,11 @@ DEFINE_int32(cinn_profiler_state,
"Specify the ProfilerState by Int in CINN, 0 for kDisabled, 1 for "
"kCPU, 2 for kCUDA, 3 for kAll, default 0.");
DEFINE_int32(cinn_schedule_error_message_level,
Int32FromEnv("FLAGS_cinn_schedule_error_message_level", 0),
"Specify the level of printing error message in the schedule."
"0 means short, 1 means detailed.");
namespace cinn {
namespace runtime {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册