Enhance Base Codegen to do clean up when generation fails

This closes #764
上级 5844fa62
......@@ -159,9 +159,8 @@ bool ExecVariableListCodegen::GenerateExecVariableList(
}
// So looks like we're going to generate code
llvm::Function* ExecVariableList_func = codegen_utils->
CreateFunction<ExecVariableListFn>(
GetUniqueFuncName());
llvm::Function* ExecVariableList_func = CreateFunction<ExecVariableListFn>(
codegen_utils, GetUniqueFuncName());
auto irb = codegen_utils->ir_builder();
......
......@@ -17,6 +17,8 @@
#include "codegen/utils/codegen_utils.h"
#include "codegen/codegen_interface.h"
#include "llvm/IR/Function.h"
namespace gpcodegen {
/** \addtogroup gpcodegen
......@@ -43,6 +45,16 @@ class BaseCodegen: public CodegenInterface {
bool GenerateCode(gpcodegen::CodegenUtils* codegen_utils) final {
is_generated_ = GenerateCodeInternal(codegen_utils);
if (!is_generated_) {
// If failed to generate, make sure we do clean up
// by erasing all the llvm functions.
for (llvm::Function* function : uncompiled_generated_functions_) {
assert(nullptr != function);
function->eraseFromParent();
}
}
// We don't need to keep these pointers any more
std::vector<llvm::Function*>().swap(uncompiled_generated_functions_);
return is_generated_;
}
......@@ -147,12 +159,39 @@ class BaseCodegen: public CodegenInterface {
**/
virtual bool GenerateCodeInternal(gpcodegen::CodegenUtils* codegen_utils) = 0;
/**
* @brief Create llvm Function for given type and store the function pointer
* in vector
*
* @note If generation fails, this class takes responsibility to clean up all
* the functions it created
*
* @tparam FunctionType Type of the function to create
*
* @param codegen_utils Utility to ease the code generation process.
* @param function_name Name of the function to create
* @return llvm::Function pointer
**/
template <typename FunctionType>
llvm::Function* CreateFunction(gpcodegen::CodegenUtils* codegen_utils,
const std::string& function_name) {
assert(nullptr != codegen_utils);
llvm::Function* function = codegen_utils->CreateFunction<FunctionType>(
function_name);
assert(nullptr != function);
uncompiled_generated_functions_.push_back(function);
return function;
}
private:
std::string orig_func_name_;
std::string unique_func_name_;
FuncPtrType regular_func_ptr_;
FuncPtrType* ptr_to_chosen_func_ptr_;
bool is_generated_;
// To track uncompiled llvm functions it creates and erase from
// llvm module on failed generations.
std::vector<llvm::Function*> uncompiled_generated_functions_;
};
/** @} */
} // namespace gpcodegen
......
......@@ -83,7 +83,7 @@ class SumCodeGenerator : public BaseCodegen<SumFunc> {
protected:
bool GenerateCodeInternal(gpcodegen::CodegenUtils* codegen_utils) final {
llvm::Function* add2_func
= codegen_utils->CreateFunction<SumFunc>(GetUniqueFuncName());
= CreateFunction<SumFunc>(codegen_utils, GetUniqueFuncName());
llvm::BasicBlock* add2_body = codegen_utils->CreateBasicBlock("body",
add2_func);
codegen_utils->ir_builder()->SetInsertPoint(add2_body);
......@@ -119,6 +119,7 @@ class FailingCodeGenerator : public BaseCodegen<SumFunc> {
static constexpr char kFailingFuncNamePrefix[] = "SumFuncFailing";
};
template <bool GEN_SUCCESS>
class UncompilableCodeGenerator : public BaseCodegen<UncompilableFunc> {
public:
explicit UncompilableCodeGenerator(
......@@ -134,14 +135,14 @@ class UncompilableCodeGenerator : public BaseCodegen<UncompilableFunc> {
protected:
bool GenerateCodeInternal(gpcodegen::CodegenUtils* codegen_utils) final {
llvm::Function* dummy_func
= codegen_utils->CreateFunction<UncompilableFunc>(
= CreateFunction<UncompilableFunc>(codegen_utils,
GetUniqueFuncName());
llvm::BasicBlock* dummy_func_body = codegen_utils->CreateBasicBlock("body",
dummy_func);
codegen_utils->ir_builder()->SetInsertPoint(dummy_func_body);
llvm::Value* int_value = codegen_utils->GetConstant(4);
codegen_utils->ir_builder()->CreateRet(int_value);
return true;
return GEN_SUCCESS;
}
private:
......@@ -150,7 +151,9 @@ class UncompilableCodeGenerator : public BaseCodegen<UncompilableFunc> {
constexpr char SumCodeGenerator::kAddFuncNamePrefix[];
constexpr char FailingCodeGenerator::kFailingFuncNamePrefix[];
constexpr char UncompilableCodeGenerator::kUncompilableFuncNamePrefix[];
template <bool GEN_SUCCESS>
constexpr char
UncompilableCodeGenerator<GEN_SUCCESS>::kUncompilableFuncNamePrefix[];
// Test environment to handle global per-process initialization tasks for all
// tests.
......@@ -220,7 +223,7 @@ TEST_F(CodegenManagerTest, GenerateCodeTest) {
// Test if generation pass with UncompiledCodeGenerator
uncompilable_func_ptr = nullptr;
EnrollCodegen<UncompilableCodeGenerator, UncompilableFunc>(
EnrollCodegen<UncompilableCodeGenerator<true>, UncompilableFunc>(
UncompilableFuncRegular, &uncompilable_func_ptr);
EXPECT_EQ(2, manager_->GenerateCode());
}
......@@ -265,6 +268,97 @@ TEST_F(CodegenManagerTest, PrepareGeneratedFunctionsNoCompilationErrorTest) {
ASSERT_TRUE(SumFuncRegular == failed_func_ptr);
}
TEST_F(CodegenManagerTest, UnCompilableFailedGenerationTest) {
// Test if generation happens successfully
sum_func_ptr = nullptr;
EnrollCodegen<SumCodeGenerator, SumFunc>(SumFuncRegular, &sum_func_ptr);
EXPECT_EQ(1, manager_->GenerateCode());
// Test if generation fails with FailingCodeGenerator
failed_func_ptr = nullptr;
EnrollCodegen<FailingCodeGenerator, SumFunc>(SumFuncRegular,
&failed_func_ptr);
// Create uncompilable generator which fails in generation
// and produce broken function
uncompilable_func_ptr = nullptr;
EnrollCodegen<UncompilableCodeGenerator<false>, UncompilableFunc>(
UncompilableFuncRegular, &uncompilable_func_ptr);
EXPECT_EQ(1, manager_->GenerateCode());
// Make sure the function pointers refer to regular versions
ASSERT_TRUE(SumFuncRegular == sum_func_ptr);
ASSERT_TRUE(SumFuncRegular == failed_func_ptr);
ASSERT_TRUE(UncompilableFuncRegular == uncompilable_func_ptr);
// This should update function pointers to generated version,
// if generation was successful
ASSERT_TRUE(manager_->PrepareGeneratedFunctions());
// For sum_func_ptr, we successfully generated code.
// So, pointer should reflect that.
ASSERT_TRUE(SumFuncRegular != sum_func_ptr);
// For failed_func_ptr, code generation was unsuccessful.
// So, pointer should not change.
ASSERT_TRUE(SumFuncRegular == failed_func_ptr);
// For uncompilable_func_ptr, code generation was unsuccessful.
// So, pointer should not change.
ASSERT_TRUE(UncompilableFuncRegular == uncompilable_func_ptr);
// Check generate SumFuncRegular works as expected;
EXPECT_EQ(3, sum_func_ptr(1, 2));
// Reset the manager, so that all the code generators go away
manager_.reset(nullptr);
// The manager reset should have restored all the function pointers
// to point to regular version
ASSERT_TRUE(SumFuncRegular == sum_func_ptr);
ASSERT_TRUE(SumFuncRegular == failed_func_ptr);
ASSERT_TRUE(UncompilableFuncRegular == uncompilable_func_ptr);
}
TEST_F(CodegenManagerTest, UnCompilablePassedGenerationTest) {
// Test if generation happens successfully
sum_func_ptr = nullptr;
EnrollCodegen<SumCodeGenerator, SumFunc>(SumFuncRegular, &sum_func_ptr);
EXPECT_EQ(1, manager_->GenerateCode());
// Test if generation fails with FailingCodeGenerator
failed_func_ptr = nullptr;
EnrollCodegen<FailingCodeGenerator, SumFunc>(SumFuncRegular,
&failed_func_ptr);
// Create uncompilable generator which generate broken
// function and return success status on generation
uncompilable_func_ptr = nullptr;
EnrollCodegen<UncompilableCodeGenerator<true>, UncompilableFunc>(
UncompilableFuncRegular, &uncompilable_func_ptr);
EXPECT_EQ(2, manager_->GenerateCode());
// Make sure both the function pointers refer to regular versions
ASSERT_TRUE(SumFuncRegular == sum_func_ptr);
ASSERT_TRUE(SumFuncRegular == failed_func_ptr);
ASSERT_TRUE(UncompilableFuncRegular == uncompilable_func_ptr);
// This should cause program to exit because of
// broken function
EXPECT_DEATH(manager_->PrepareGeneratedFunctions(), "");
// Reset the manager, so that all the code generators go away
manager_.reset(nullptr);
// The manager reset should have restored all the function pointers
// to point to regular version
ASSERT_TRUE(SumFuncRegular == sum_func_ptr);
ASSERT_TRUE(SumFuncRegular == failed_func_ptr);
ASSERT_TRUE(UncompilableFuncRegular == uncompilable_func_ptr);
}
TEST_F(CodegenManagerTest, ResetTest) {
sum_func_ptr = nullptr;
SumCodeGenerator* code_gen = new SumCodeGenerator(SumFuncRegular,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册