提交 60abff33 编写于 作者: Z zhang2014

ISSUES-5436 support custom http [part 3]

上级 183eb82b
......@@ -12,13 +12,9 @@ namespace DB
{
class Context;
class CustomExecutor;
struct HTTPInputStreams;
struct HTTPOutputStreams;
using duration = std::chrono::steady_clock::duration;
using HTTPMatchExecutorPtr = std::shared_ptr<CustomExecutor>;
class CustomExecutor
{
public:
......@@ -38,4 +34,6 @@ public:
virtual QueryExecutors getQueryExecutor(Context & context, HTTPServerRequest & request, HTMLForm & params, const HTTPInputStreams & input_streams) const = 0;
};
using CustomExecutorPtr = std::shared_ptr<CustomExecutor>;
}
......@@ -6,6 +6,7 @@
#include <Interpreters/Context.h>
#include <Interpreters/executeQuery.h>
#include <Interpreters/CustomHTTP/HTTPOutputStreams.h>
#include "HTTPInputStreams.h"
namespace DB
......@@ -25,12 +26,20 @@ public:
QueryExecutors getQueryExecutor(Context & context, HTTPServerRequest & request, HTMLForm & params, const HTTPInputStreams & input_streams) const override
{
ReadBufferPtr in = prepareAndGetQueryInput(context, request, params, input_streams);
return {[&, shared_in = in](HTTPOutputStreams & output, HTTPServerResponse & response)
return {[&](HTTPOutputStreams & output, HTTPServerResponse & response)
{
const auto & execute_query = prepareQuery(context, params);
ReadBufferPtr execute_query_buf = std::make_shared<ReadBufferFromString>(execute_query);
ReadBufferPtr temp_query_buf;
if (!startsWith(request.getContentType().data(), "multipart/form-data"))
{
temp_query_buf = execute_query_buf; /// we create a temporary reference for not to be destroyed
execute_query_buf = std::make_unique<ConcatReadBuffer>(*temp_query_buf, *input_streams.in_maybe_internal_compressed);
}
executeQuery(
*shared_in, *output.out_maybe_delayed_and_compressed, /* allow_into_outfile = */ false, context,
*execute_query_buf, *output.out_maybe_delayed_and_compressed, /* allow_into_outfile = */ false, context,
[&response] (const String & content_type) { response.setContentType(content_type); },
[&response] (const String & current_query_id) { response.add("X-ClickHouse-Query-Id", current_query_id); }
);
......@@ -38,12 +47,21 @@ public:
}
private:
ReadBufferPtr prepareAndGetQueryInput(Context & context, HTTPServerRequest & request, HTMLForm & params, const HTTPInputStreams & input_streams) const
String prepareQuery(Context & context, HTMLForm & params) const
{
for (const auto & [key, value] : params)
{
const static size_t prefix_size = strlen("param_");
std::stringstream query_stream;
for (const auto & param : params)
{
if (param.first == "query")
query_stream << param.second;
else if (startsWith(param.first, "param_"))
context.setQueryParameter(param.first.substr(prefix_size), param.second);
}
query_stream << "\n";
return query_stream.str();
}
};
......
......@@ -12,7 +12,7 @@ namespace DB
class ExtractorContextChange
{
public:
ExtractorContextChange(Context & context_, const HTTPMatchExecutorPtr & executor_) : context(context_), executor(executor_) {}
ExtractorContextChange(Context & context_, const CustomExecutorPtr & executor_) : context(context_), executor(executor_) {}
void extract(Poco::Net::HTTPServerRequest & request, HTMLForm & params)
{
......@@ -91,7 +91,7 @@ public:
private:
Context & context;
const HTTPMatchExecutorPtr & executor;
const CustomExecutorPtr & executor;
};
......
......@@ -52,10 +52,17 @@ namespace
}
}
HTTPOutputStreams::HTTPOutputStreams(HTTPServerRequest & request, HTTPServerResponse & response, bool internal_compress, size_t keep_alive_timeout)
: out(createResponseOut(request, response, keep_alive_timeout))
, out_maybe_compressed(createMaybeCompressionOut(internal_compress, out))
, out_maybe_delayed_and_compressed(out_maybe_compressed)
{
}
HTTPOutputStreams::HTTPOutputStreams(
Context & context, HTTPServerRequest & request, HTTPServerResponse & response, HTMLForm & form, size_t keep_alive_timeout)
: out(createResponseOut(request, response, keep_alive_timeout))
, out_maybe_compressed(createMaybeCompressionOut(form, out))
, out_maybe_compressed(createMaybeCompressionOut(form.getParsed<bool>("compress", false), out))
, out_maybe_delayed_and_compressed(createMaybeDelayedAndCompressionOut(context, form, out_maybe_compressed))
{
Settings & settings = context.getSettingsRef();
......@@ -68,8 +75,7 @@ HTTPOutputStreams::HTTPOutputStreams(
out->setSendProgressInterval(settings.http_headers_progress_interval_ms);
/// Add CORS header if 'add_http_cors_header' setting is turned on and the client passed
/// Origin header.
/// Add CORS header if 'add_http_cors_header' setting is turned on and the client passed Origin header.
out->addHeaderCORS(settings.add_http_cors_header && !request.get("Origin", "").empty());
/// While still no data has been sent, we will report about query execution progress by sending HTTP headers.
......@@ -84,8 +90,7 @@ HTTPOutputStreams::HTTPOutputStreams(
}
}
std::shared_ptr<WriteBufferFromHTTPServerResponse> HTTPOutputStreams::createResponseOut(
HTTPServerRequest &request, HTTPServerResponse &response, size_t keep_alive_timeout)
HTTPResponseBufferPtr HTTPOutputStreams::createResponseOut(HTTPServerRequest & request, HTTPServerResponse & response, size_t keep_alive)
{
/// The client can pass a HTTP header indicating supported compression method (gzip or deflate).
String http_response_compression_methods = request.get("Accept-Encoding", "");
......@@ -96,27 +101,26 @@ std::shared_ptr<WriteBufferFromHTTPServerResponse> HTTPOutputStreams::createResp
/// NOTE parsing of the list of methods is slightly incorrect.
if (std::string::npos != http_response_compression_methods.find("gzip"))
return std::make_shared<WriteBufferFromHTTPServerResponse>(
request, response, keep_alive_timeout, true, CompressionMethod::Gzip, DBMS_DEFAULT_BUFFER_SIZE);
request, response, keep_alive, true, CompressionMethod::Gzip, DBMS_DEFAULT_BUFFER_SIZE, response.sent());
else if (std::string::npos != http_response_compression_methods.find("deflate"))
return std::make_shared<WriteBufferFromHTTPServerResponse>(
request, response, keep_alive_timeout, true, CompressionMethod::Zlib, DBMS_DEFAULT_BUFFER_SIZE);
request, response, keep_alive, true, CompressionMethod::Zlib, DBMS_DEFAULT_BUFFER_SIZE, response.sent());
#if USE_BROTLI
else if (http_response_compression_methods == "br")
return std::make_shared<WriteBufferFromHTTPServerResponse>(
request, response, keep_alive_timeout, true, CompressionMethod::Brotli, DBMS_DEFAULT_BUFFER_SIZE);
request, response, keep_alive, true, CompressionMethod::Brotli, DBMS_DEFAULT_BUFFER_SIZE, response.sent());
#endif
}
return std::make_shared<WriteBufferFromHTTPServerResponse>(
request, response, keep_alive_timeout, false, CompressionMethod{}, DBMS_DEFAULT_BUFFER_SIZE);
return std::make_shared<WriteBufferFromHTTPServerResponse>(request, response, keep_alive, false, CompressionMethod{}, DBMS_DEFAULT_BUFFER_SIZE, response.sent());
}
WriteBufferPtr HTTPOutputStreams::createMaybeCompressionOut(HTMLForm & form, std::shared_ptr<WriteBufferFromHTTPServerResponse> & out_)
WriteBufferPtr HTTPOutputStreams::createMaybeCompressionOut(bool compression, HTTPResponseBufferPtr & out_)
{
/// Client can pass a 'compress' flag in the query string. In this case the query result is
/// compressed using internal algorithm. This is not reflected in HTTP headers.
bool internal_compression = form.getParsed<bool>("compress", false);
return internal_compression ? std::make_shared<CompressedWriteBuffer>(*out_) : WriteBufferPtr(out_);
// bool internal_compression = form.getParsed<bool>("compress", false);
return compression ? std::make_shared<CompressedWriteBuffer>(*out_) : WriteBufferPtr(out_);
}
WriteBufferPtr HTTPOutputStreams::createMaybeDelayedAndCompressionOut(Context & context, HTMLForm & form, WriteBufferPtr & out_)
......@@ -172,6 +176,20 @@ WriteBufferPtr HTTPOutputStreams::createMaybeDelayedAndCompressionOut(Context &
return out_;
}
HTTPOutputStreams::~HTTPOutputStreams()
{
/// Destroy CascadeBuffer to actualize buffers' positions and reset extra references
if (out_maybe_delayed_and_compressed != out_maybe_compressed)
out_maybe_delayed_and_compressed.reset();
/// If buffer has data, and that data wasn't sent yet, then no need to send that data
if (out->count() == out->offset())
{
out_maybe_compressed->position() = out_maybe_compressed->buffer().begin();
out->position() = out->buffer().begin();
}
}
void HTTPOutputStreams::finalize() const
{
if (out_maybe_delayed_and_compressed != out_maybe_compressed)
......@@ -208,8 +226,9 @@ void HTTPOutputStreams::finalize() const
copyData(concat_read_buffer, *out_maybe_compressed);
}
/// Send HTTP headers with code 200 if no exception happened and the data is still not sent to
/// the client.
/// Send HTTP headers with code 200 if no exception happened and the data is still not sent to the client.
out_maybe_compressed->next();
out->next();
out->finalize();
}
......
......@@ -12,7 +12,7 @@ namespace DB
using HTTPServerRequest = Poco::Net::HTTPServerRequest;
using HTTPServerResponse = Poco::Net::HTTPServerResponse;
using HTTPResponseBufferPtr = std::shared_ptr<WriteBufferFromHTTPServerResponse>;
/* Raw data
* ↓
......@@ -24,25 +24,27 @@ using HTTPServerResponse = Poco::Net::HTTPServerResponse;
*/
struct HTTPOutputStreams
{
using HTTPResponseBufferPtr = std::shared_ptr<WriteBufferFromHTTPServerResponse>;
HTTPResponseBufferPtr out;
/// Points to 'out' or to CompressedWriteBuffer(*out), depending on settings.
std::shared_ptr<WriteBuffer> out_maybe_compressed;
/// Points to 'out' or to CompressedWriteBuffer(*out) or to CascadeWriteBuffer.
std::shared_ptr<WriteBuffer> out_maybe_delayed_and_compressed;
HTTPOutputStreams() = default;
HTTPOutputStreams(Context & context, HTTPServerRequest & request, HTTPServerResponse & response, HTMLForm & form, size_t keep_alive_timeout);
~HTTPOutputStreams();
void finalize() const;
WriteBufferPtr createMaybeDelayedAndCompressionOut(Context &context, HTMLForm &form, WriteBufferPtr &out_);
WriteBufferPtr createMaybeCompressionOut(HTMLForm & form, std::shared_ptr<WriteBufferFromHTTPServerResponse> & out_);
WriteBufferPtr createMaybeCompressionOut(bool compression, std::shared_ptr<WriteBufferFromHTTPServerResponse> & out_);
HTTPResponseBufferPtr createResponseOut(HTTPServerRequest & request, HTTPServerResponse & response, size_t keep_alive_timeout);
HTTPResponseBufferPtr createResponseOut(HTTPServerRequest & request, HTTPServerResponse & response, size_t keep_alive);
HTTPOutputStreams(HTTPServerRequest & request, HTTPServerResponse & response, bool internal_compress, size_t keep_alive_timeout);
HTTPOutputStreams(Context & context, HTTPServerRequest & request, HTTPServerResponse & response, HTMLForm & form, size_t keep_alive_timeout);
};
using HTTPOutputStreamsPtr = std::unique_ptr<HTTPOutputStreams>;
}
......@@ -192,7 +192,7 @@ HTTPHandler::SessionContextHolder::SessionContextHolder(IServer & accepted_serve
if (!session_id.empty())
{
session_timeout = parseSessionTimeout(accepted_server.config(), params);
session_context = context->acquireSession(session_id, session_timeout, params.check("session_check", "1"));
session_context = context->acquireSession(session_id, session_timeout, params.check<String>("session_check", "1"));
context = std::make_unique<Context>(*session_context);
context->setSessionContext(*session_context);
......@@ -241,31 +241,26 @@ void HTTPHandler::SessionContextHolder::authentication(HTTPServerRequest & reque
context->setCurrentQueryId(query_id);
}
void HTTPHandler::processQuery(HTTPRequest & request, HTMLForm & params, HTTPResponse & response, SessionContextHolder & holder)
void HTTPHandler::processQuery(Context & context, HTTPRequest & request, HTMLForm & params, HTTPResponse & response)
{
const auto & [name, custom_executor] = holder.context->getCustomExecutor(request/*, params*/);
const auto & name_with_custom_executor = context.getCustomExecutor(request/*, params*/);
LOG_TRACE(log, "Using " << name_with_custom_executor.first << " to execute URI: " << request.getURI());
LOG_TRACE(log, "Using " << name << " to execute URI: " << request.getURI());
ExtractorClientInfo{context.getClientInfo()}.extract(request);
ExtractorContextChange{context, name_with_custom_executor.second}.extract(request, params);
ExtractorClientInfo{holder.context->getClientInfo()}.extract(request);
ExtractorContextChange{*holder.context.get(), custom_executor}.extract(request, params);
HTTPInputStreams input_streams{context, request, params};
HTTPOutputStreams output_streams = HTTPOutputStreams(context, request, response, params, getKeepAliveTimeout());
auto & config = server.config();
HTTPInputStreams input_streams{*holder.context, request, params};
HTTPOutputStreams output_streams(*holder.context, request, response, params, config.getUInt("keep_alive_timeout", 10));
const auto & query_executors = custom_executor->getQueryExecutor(*holder.context, request, params, input_streams);
const auto & query_executors = name_with_custom_executor.second->getQueryExecutor(context, request, params, input_streams);
for (const auto & query_executor : query_executors)
query_executor(output_streams, response);
output_streams.finalize(); /// Send HTTP headers with code 200 if no exception happened and the data is still not sent to the client.
LOG_INFO(log, "Done processing query");
}
void HTTPHandler::trySendExceptionToClient(const std::string & message, int exception_code,
Poco::Net::HTTPServerRequest & request, Poco::Net::HTTPServerResponse & response,
HTTPOutputStreams & used_output)
void HTTPHandler::trySendExceptionToClient(
const std::string & message, int exception_code, Poco::Net::HTTPServerRequest & request, Poco::Net::HTTPServerResponse & response, bool compression)
{
try
{
......@@ -273,57 +268,25 @@ void HTTPHandler::trySendExceptionToClient(const std::string & message, int exce
/// If HTTP method is POST and Keep-Alive is turned on, we should read the whole request body
/// to avoid reading part of the current request body in the next request.
if (request.getMethod() == Poco::Net::HTTPRequest::HTTP_POST
&& response.getKeepAlive()
&& !request.stream().eof()
&& exception_code != ErrorCodes::HTTP_LENGTH_REQUIRED)
{
if (request.getMethod() == Poco::Net::HTTPRequest::HTTP_POST && response.getKeepAlive()
&& !request.stream().eof() && exception_code != ErrorCodes::HTTP_LENGTH_REQUIRED)
request.stream().ignore(std::numeric_limits<std::streamsize>::max());
}
bool auth_fail = exception_code == ErrorCodes::UNKNOWN_USER ||
exception_code == ErrorCodes::WRONG_PASSWORD ||
exception_code == ErrorCodes::REQUIRED_PASSWORD;
if (auth_fail)
if (exception_code == ErrorCodes::UNKNOWN_USER || exception_code == ErrorCodes::WRONG_PASSWORD ||
exception_code == ErrorCodes::REQUIRED_PASSWORD || exception_code != ErrorCodes::HTTP_LENGTH_REQUIRED)
{
response.requireAuthentication("ClickHouse server HTTP API");
response.send() << message << std::endl;
}
else
{
response.setStatusAndReason(exceptionCodeToHTTPStatus(exception_code));
}
HTTPOutputStreams output_streams(request, response, compression, getKeepAliveTimeout());
if (!response.sent() && !used_output.out_maybe_compressed)
{
/// If nothing was sent yet and we don't even know if we must compress the response.
response.send() << message << std::endl;
}
else if (used_output.out_maybe_compressed)
{
/// Destroy CascadeBuffer to actualize buffers' positions and reset extra references
if (used_output.out_maybe_delayed_and_compressed != used_output.out_maybe_compressed)
used_output.out_maybe_delayed_and_compressed.reset();
/// Send the error message into already used (and possibly compressed) stream.
/// Note that the error message will possibly be sent after some data.
/// Also HTTP code 200 could have already been sent.
/// If buffer has data, and that data wasn't sent yet, then no need to send that data
bool data_sent = used_output.out->count() != used_output.out->offset();
writeString(message, *output_streams.out_maybe_compressed);
writeChar('\n', *output_streams.out_maybe_compressed);
if (!data_sent)
{
used_output.out_maybe_compressed->position() = used_output.out_maybe_compressed->buffer().begin();
used_output.out->position() = used_output.out->buffer().begin();
}
writeString(message, *used_output.out_maybe_compressed);
writeChar('\n', *used_output.out_maybe_compressed);
used_output.out_maybe_compressed->next();
used_output.out->next();
used_output.out->finalize();
output_streams.finalize();
}
}
catch (...)
......@@ -337,10 +300,8 @@ void HTTPHandler::handleRequest(Poco::Net::HTTPServerRequest & request, Poco::Ne
setThreadName("HTTPHandler");
ThreadStatus thread_status;
HTTPOutputStreams used_output;
/// In case of exception, send stack trace to client.
bool with_stacktrace = false;
bool with_stacktrace = false, internal_compression = false;
try
{
......@@ -353,6 +314,7 @@ void HTTPHandler::handleRequest(Poco::Net::HTTPServerRequest & request, Poco::Ne
HTMLForm params(request);
with_stacktrace = params.getParsed<bool>("stacktrace", false);
internal_compression = params.getParsed<bool>("compress", false);
/// Workaround. Poco does not detect 411 Length Required case.
if (request.getMethod() == Poco::Net::HTTPRequest::HTTP_POST && !request.getChunkedTransferEncoding() && !request.hasContentLength())
......@@ -363,7 +325,8 @@ void HTTPHandler::handleRequest(Poco::Net::HTTPServerRequest & request, Poco::Ne
CurrentThread::QueryScope query_scope(*holder.context);
holder.authentication(request, params);
processQuery(request, params, response, holder);
processQuery(*holder.context, request, params, response);
LOG_INFO(log, "Done processing query");
}
}
catch (...)
......@@ -375,8 +338,7 @@ void HTTPHandler::handleRequest(Poco::Net::HTTPServerRequest & request, Poco::Ne
*/
int exception_code = getCurrentExceptionCode();
std::string exception_message = getCurrentExceptionMessage(with_stacktrace, true);
trySendExceptionToClient(exception_message, exception_code, request, response, HTTPOutputStreams{});
trySendExceptionToClient(exception_message, exception_code, request, response, internal_compression);
}
}
......
......@@ -56,11 +56,11 @@ private:
CurrentMetrics::Increment metric_increment{CurrentMetrics::HTTPConnection};
/// Also initializes 'used_output'.
void processQuery(HTTPRequest & request, HTMLForm & params, HTTPResponse & response, SessionContextHolder & holder);
size_t getKeepAliveTimeout() { return server.config().getUInt("keep_alive_timeout", 10); }
void trySendExceptionToClient(
const std::string & message, int exception_code, HTTPRequest & request, HTTPResponse & response, HTTPOutputStreams & used_output);
void processQuery(Context & context, HTTPRequest & request, HTMLForm & params, HTTPResponse & response);
void trySendExceptionToClient(const std::string & message, int exception_code, HTTPRequest & request, HTTPResponse & response, bool compression);
};
......
......@@ -153,13 +153,17 @@ WriteBufferFromHTTPServerResponse::WriteBufferFromHTTPServerResponse(
Poco::Net::HTTPServerResponse & response_,
unsigned keep_alive_timeout_,
bool compress_,
CompressionMethod compression_method_)
: BufferWithOwnMemory<WriteBuffer>(DBMS_DEFAULT_BUFFER_SIZE)
CompressionMethod compression_method_,
size_t size,
bool finish_send_headers_)
: BufferWithOwnMemory<WriteBuffer>(size)
, request(request_)
, response(response_)
, keep_alive_timeout(keep_alive_timeout_)
, compress(compress_)
, compression_method(compression_method_)
, headers_started_sending(finish_send_headers_)
, headers_finished_sending(finish_send_headers_)
{
}
......
......@@ -95,7 +95,9 @@ public:
Poco::Net::HTTPServerResponse & response_,
unsigned keep_alive_timeout_,
bool compress_ = false, /// If true - set Content-Encoding header and compress the result.
CompressionMethod compression_method_ = CompressionMethod::None);
CompressionMethod compression_method_ = CompressionMethod::Gzip,
size_t size = DBMS_DEFAULT_BUFFER_SIZE,
bool finish_send_headers_ = false);
/// Writes progess in repeating HTTP headers.
void onProgress(const Progress & progress);
......
......@@ -2044,9 +2044,9 @@ void Context::resetInputCallbacks()
input_blocks_reader = {};
}
std::pair<String, HTTPMatchExecutorPtr> Context::getCustomExecutor(Poco::Net::HTTPServerRequest & /*request*/)
std::pair<String, CustomExecutorPtr> Context::getCustomExecutor(Poco::Net::HTTPServerRequest & /*request*/)
{
return std::pair<String, HTTPMatchExecutorPtr>("Default", std::shared_ptr<CustomExecutorDefault>());
return std::pair<String, CustomExecutorPtr>("Default", std::make_shared<CustomExecutorDefault>());
}
......
......@@ -490,7 +490,7 @@ public:
Compiler & getCompiler();
std::pair<String, HTTPMatchExecutorPtr> getCustomExecutor(Poco::Net::HTTPServerRequest &request/*, HTMLForm & params*/);
std::pair<String, CustomExecutorPtr> getCustomExecutor(Poco::Net::HTTPServerRequest &request/*, HTMLForm & params*/);
/// Call after initialization before using system logs. Call for global context.
void initializeSystemLogs();
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册