提交 afef5c6c 编写于 作者: A Alexey Milovidov

Added stack protection; added a test

上级 c80aeb0e
......@@ -442,6 +442,7 @@ namespace ErrorCodes
extern const int CANNOT_PARSE_DWARF = 465;
extern const int INSECURE_PATH = 466;
extern const int CANNOT_PARSE_BOOL = 467;
extern const int CANNOT_PTHREAD_ATTR = 468;
extern const int KEEPER_EXCEPTION = 999;
extern const int POCO_EXCEPTION = 1000;
......
#include <Common/checkStackSize.h>
#include <Common/Exception.h>
#include <ext/scope_guard.h>
#include <pthread.h>
#include <cstdint>
#include <sstream>
namespace DB
{
namespace ErrorCodes
{
extern const int CANNOT_PTHREAD_ATTR;
extern const int LOGICAL_ERROR;
extern const int TOO_DEEP_RECURSION;
}
}
static thread_local void * stack_address = nullptr;
static thread_local size_t max_stack_size = 0;
void checkStackSize()
{
using namespace DB;
if (!stack_address)
{
pthread_attr_t attr;
if (0 != pthread_getattr_np(pthread_self(), &attr))
throwFromErrno("Cannot pthread_getattr_np", ErrorCodes::CANNOT_PTHREAD_ATTR);
SCOPE_EXIT({ pthread_attr_destroy(&attr); });
if (0 != pthread_attr_getstack(&attr, &stack_address, &max_stack_size))
throwFromErrno("Cannot pthread_getattr_np", ErrorCodes::CANNOT_PTHREAD_ATTR);
}
const void * frame_address = __builtin_frame_address(0);
uintptr_t int_frame_address = reinterpret_cast<uintptr_t>(frame_address);
uintptr_t int_stack_address = reinterpret_cast<uintptr_t>(stack_address);
/// We assume that stack grows towards lower addresses. And that it starts to grow from the end of a chunk of memory of max_stack_size.
if (int_frame_address > int_stack_address + max_stack_size)
throw Exception("Logical error: frame address is greater than stack begin address", ErrorCodes::LOGICAL_ERROR);
size_t stack_size = int_stack_address + max_stack_size - int_frame_address;
/// Just check if we have already eat more than a half of stack size. It's a bit overkill (a half of stack size is wasted).
/// It's safe to assume that overflow in multiplying by two cannot occur.
if (stack_size * 2 > max_stack_size)
{
std::stringstream message;
message << "Stack size too large"
<< ". Stack address: " << stack_address
<< ", frame address: " << frame_address
<< ", stack size: " << stack_size
<< ", maximum stack size: " << max_stack_size;
throw Exception(message.str(), ErrorCodes::TOO_DEEP_RECURSION);
}
}
#pragma once
/** If the stack is large enough and is near its size, throw an exception.
* You can call this function in "heavy" functions that may be called recursively
* to prevent possible stack overflows.
*/
void checkStackSize();
......@@ -6,6 +6,7 @@
#include <Storages/StorageReplicatedMergeTree.h>
#include <Common/Exception.h>
#include <Common/ProfileEvents.h>
#include <Common/checkStackSize.h>
#include <TableFunctions/TableFunctionFactory.h>
#include <common/logger_useful.h>
......@@ -58,6 +59,8 @@ namespace
BlockInputStreamPtr createLocalStream(const ASTPtr & query_ast, const Context & context, QueryProcessingStage::Enum processed_stage)
{
checkStackSize();
InterpreterSelectQuery interpreter{query_ast, context, SelectQueryOptions(processed_stage)};
BlockInputStreamPtr stream = interpreter.execute().in;
......
......@@ -2,6 +2,7 @@
#include <IO/ReadBufferFromMemory.h>
#include <Common/typeid_cast.h>
#include <Common/checkStackSize.h>
#include <DataStreams/AddingDefaultBlockOutputStream.h>
#include <DataStreams/AddingDefaultsBlockInputStream.h>
......@@ -39,6 +40,7 @@ InterpreterInsertQuery::InterpreterInsertQuery(
const ASTPtr & query_ptr_, const Context & context_, bool allow_materialized_)
: query_ptr(query_ptr_), context(context_), allow_materialized(allow_materialized_)
{
checkStackSize();
}
......
......@@ -58,6 +58,7 @@
#include <Core/Types.h>
#include <Columns/Collator.h>
#include <Common/typeid_cast.h>
#include <Common/checkStackSize.h>
#include <Parsers/queryToString.h>
#include <ext/map.h>
#include <memory>
......@@ -211,6 +212,8 @@ InterpreterSelectQuery::InterpreterSelectQuery(
, input(input_)
, log(&Logger::get("InterpreterSelectQuery"))
{
checkStackSize();
initSettings();
const Settings & settings = context.getSettingsRef();
......
......@@ -23,6 +23,7 @@
#include <DataTypes/DataTypeString.h>
#include <Columns/ColumnString.h>
#include <Common/typeid_cast.h>
#include <Common/checkStackSize.h>
#include <Databases/IDatabase.h>
#include <Core/SettingsCommon.h>
#include <ext/range.h>
......@@ -387,6 +388,7 @@ StorageMerge::StorageListWithLocks StorageMerge::getSelectedTables(const ASTPtr
DatabaseIteratorPtr StorageMerge::getDatabaseIterator(const Context & context) const
{
checkStackSize();
auto database = context.getDatabase(source_database);
auto table_name_match = [this](const String & table_name_) { return table_name_regexp.match(table_name_); };
return database->getIterator(global_context, table_name_match);
......
DROP TABLE IF EXISTS merge1;
DROP TABLE IF EXISTS merge2;
CREATE TABLE IF NOT EXISTS merge1 (x UInt64) ENGINE = Merge(currentDatabase(), '^merge\\d$');
CREATE TABLE IF NOT EXISTS merge2 (x UInt64) ENGINE = Merge(currentDatabase(), '^merge\\d$');
SELECT * FROM merge1; -- { serverError 306 }
SELECT * FROM merge2; -- { serverError 306 }
DROP TABLE merge1;
DROP TABLE merge2;
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册