提交 64d49e45 编写于 作者: V Vladimir Chebotarev

Minor review fixes.

上级 0ee6f623
......@@ -5,11 +5,11 @@
#include <common/logger_useful.h>
#define DEFAULT_S3_MAX_FOLLOW_GET_REDIRECT 2
namespace DB
{
const int DEFAULT_S3_MAX_FOLLOW_GET_REDIRECT = 2;
ReadBufferFromS3::ReadBufferFromS3(Poco::URI uri_,
const ConnectionTimeouts & timeouts,
const Poco::Net::HTTPBasicCredentials & credentials,
......
......@@ -11,12 +11,13 @@
#include <common/logger_useful.h>
#define DEFAULT_S3_MAX_FOLLOW_PUT_REDIRECT 2
#define S3_SOFT_MAX_PARTS 10000
namespace DB
{
const int DEFAULT_S3_MAX_FOLLOW_PUT_REDIRECT = 2;
const int S3_WARN_MAX_PARTS = 10000;
namespace ErrorCodes
{
extern const int INCORRECT_DATA;
......@@ -92,34 +93,33 @@ void WriteBufferFromS3::initiate()
{
// See https://docs.aws.amazon.com/AmazonS3/latest/API/mpUploadInitiate.html
Poco::Net::HTTPResponse response;
std::unique_ptr<Poco::Net::HTTPRequest> request;
std::unique_ptr<Poco::Net::HTTPRequest> request_ptr;
HTTPSessionPtr session;
std::istream * istr = nullptr; /// owned by session
Poco::URI initiate_uri = uri;
initiate_uri.setRawQuery("uploads");
auto params = uri.getQueryParameters();
for (auto it = params.begin(); it != params.end(); ++it)
for (auto & param: uri.getQueryParameters())
{
initiate_uri.addQueryParameter(it->first, it->second);
initiate_uri.addQueryParameter(param.first, param.second);
}
for (int i = 0; i < DEFAULT_S3_MAX_FOLLOW_PUT_REDIRECT; ++i)
{
session = makeHTTPSession(initiate_uri, timeouts);
request = std::make_unique<Poco::Net::HTTPRequest>(Poco::Net::HTTPRequest::HTTP_POST, initiate_uri.getPathAndQuery(), Poco::Net::HTTPRequest::HTTP_1_1);
request->setHost(initiate_uri.getHost()); // use original, not resolved host name in header
request_ptr = std::make_unique<Poco::Net::HTTPRequest>(Poco::Net::HTTPRequest::HTTP_POST, initiate_uri.getPathAndQuery(), Poco::Net::HTTPRequest::HTTP_1_1);
request_ptr->setHost(initiate_uri.getHost()); // use original, not resolved host name in header
if (auth_request.hasCredentials())
{
Poco::Net::HTTPBasicCredentials credentials(auth_request);
credentials.authenticate(*request);
credentials.authenticate(*request_ptr);
}
request->setContentLength(0);
request_ptr->setContentLength(0);
LOG_TRACE((&Logger::get("WriteBufferFromS3")), "Sending request to " << initiate_uri.toString());
session->sendRequest(*request);
session->sendRequest(*request_ptr);
istr = &session->receiveResponse(response);
......@@ -134,7 +134,7 @@ void WriteBufferFromS3::initiate()
initiate_uri = location_iterator->second;
}
assertResponseIsOk(*request, response, *istr);
assertResponseIsOk(*request_ptr, response, *istr);
Poco::XML::InputSource src(*istr);
Poco::XML::DOMParser parser;
......@@ -156,37 +156,38 @@ void WriteBufferFromS3::writePart(const String & data)
{
// See https://docs.aws.amazon.com/AmazonS3/latest/API/mpUploadUploadPart.html
Poco::Net::HTTPResponse response;
std::unique_ptr<Poco::Net::HTTPRequest> request;
std::unique_ptr<Poco::Net::HTTPRequest> request_ptr;
HTTPSessionPtr session;
std::istream * istr = nullptr; /// owned by session
Poco::URI part_uri = uri;
part_uri.addQueryParameter("partNumber", std::to_string(part_tags.size() + 1));
part_uri.addQueryParameter("uploadId", upload_id);
if (part_tags.size() == S3_SOFT_MAX_PARTS)
if (part_tags.size() == S3_WARN_MAX_PARTS)
{
// Don't throw exception here by ourselves but leave the decision to take by S3 server.
LOG_WARNING(&Logger::get("WriteBufferFromS3"), "Maximum part number in S3 protocol has reached (too much parts). Server may not accept this whole upload.");
}
for (int i = 0; i < DEFAULT_S3_MAX_FOLLOW_PUT_REDIRECT; ++i)
{
session = makeHTTPSession(part_uri, timeouts);
request = std::make_unique<Poco::Net::HTTPRequest>(Poco::Net::HTTPRequest::HTTP_PUT, part_uri.getPathAndQuery(), Poco::Net::HTTPRequest::HTTP_1_1);
request->setHost(part_uri.getHost()); // use original, not resolved host name in header
request_ptr = std::make_unique<Poco::Net::HTTPRequest>(Poco::Net::HTTPRequest::HTTP_PUT, part_uri.getPathAndQuery(), Poco::Net::HTTPRequest::HTTP_1_1);
request_ptr->setHost(part_uri.getHost()); // use original, not resolved host name in header
if (auth_request.hasCredentials())
{
Poco::Net::HTTPBasicCredentials credentials(auth_request);
credentials.authenticate(*request);
credentials.authenticate(*request_ptr);
}
request->setExpectContinue(true);
request_ptr->setExpectContinue(true);
request->setContentLength(data.size());
request_ptr->setContentLength(data.size());
LOG_TRACE((&Logger::get("WriteBufferFromS3")), "Sending request to " << part_uri.toString());
std::ostream & ostr = session->sendRequest(*request);
std::ostream & ostr = session->sendRequest(*request_ptr);
if (session->peekResponse(response))
{
// Received 100-continue.
......@@ -206,7 +207,7 @@ void WriteBufferFromS3::writePart(const String & data)
part_uri = location_iterator->second;
}
assertResponseIsOk(*request, response, *istr);
assertResponseIsOk(*request_ptr, response, *istr);
auto etag_iterator = response.find("ETag");
if (etag_iterator == response.end())
......@@ -221,7 +222,7 @@ void WriteBufferFromS3::complete()
{
// See https://docs.aws.amazon.com/AmazonS3/latest/API/mpUploadComplete.html
Poco::Net::HTTPResponse response;
std::unique_ptr<Poco::Net::HTTPRequest> request;
std::unique_ptr<Poco::Net::HTTPRequest> request_ptr;
HTTPSessionPtr session;
std::istream * istr = nullptr; /// owned by session
Poco::URI complete_uri = uri;
......@@ -244,22 +245,22 @@ void WriteBufferFromS3::complete()
for (int i = 0; i < DEFAULT_S3_MAX_FOLLOW_PUT_REDIRECT; ++i)
{
session = makeHTTPSession(complete_uri, timeouts);
request = std::make_unique<Poco::Net::HTTPRequest>(Poco::Net::HTTPRequest::HTTP_POST, complete_uri.getPathAndQuery(), Poco::Net::HTTPRequest::HTTP_1_1);
request->setHost(complete_uri.getHost()); // use original, not resolved host name in header
request_ptr = std::make_unique<Poco::Net::HTTPRequest>(Poco::Net::HTTPRequest::HTTP_POST, complete_uri.getPathAndQuery(), Poco::Net::HTTPRequest::HTTP_1_1);
request_ptr->setHost(complete_uri.getHost()); // use original, not resolved host name in header
if (auth_request.hasCredentials())
{
Poco::Net::HTTPBasicCredentials credentials(auth_request);
credentials.authenticate(*request);
credentials.authenticate(*request_ptr);
}
request->setExpectContinue(true);
request_ptr->setExpectContinue(true);
request->setContentLength(data.size());
request_ptr->setContentLength(data.size());
LOG_TRACE((&Logger::get("WriteBufferFromS3")), "Sending request to " << complete_uri.toString());
std::ostream & ostr = session->sendRequest(*request);
std::ostream & ostr = session->sendRequest(*request_ptr);
if (session->peekResponse(response))
{
// Received 100-continue.
......@@ -279,7 +280,7 @@ void WriteBufferFromS3::complete()
complete_uri = location_iterator->second;
}
assertResponseIsOk(*request, response, *istr);
assertResponseIsOk(*request_ptr, response, *istr);
}
}
......@@ -15,7 +15,7 @@ logging.getLogger().addHandler(logging.StreamHandler())
def get_communication_data(started_cluster):
conn = httplib.HTTPConnection(started_cluster.instances['dummy'].ip_address, started_cluster.communication_port)
conn = httplib.HTTPConnection(started_cluster.instances["dummy"].ip_address, started_cluster.communication_port)
conn.request("GET", "/")
r = conn.getresponse()
raw_data = r.read()
......@@ -24,7 +24,7 @@ def get_communication_data(started_cluster):
def put_communication_data(started_cluster, body):
conn = httplib.HTTPConnection(started_cluster.instances['dummy'].ip_address, started_cluster.communication_port)
conn = httplib.HTTPConnection(started_cluster.instances["dummy"].ip_address, started_cluster.communication_port)
conn.request("PUT", "/", body)
r = conn.getresponse()
conn.close()
......@@ -34,29 +34,29 @@ def put_communication_data(started_cluster, body):
def started_cluster():
try:
cluster = ClickHouseCluster(__file__)
instance = cluster.add_instance('dummy', config_dir="configs", main_configs=['configs/min_chunk_size.xml'])
instance = cluster.add_instance("dummy", config_dir="configs", main_configs=["configs/min_chunk_size.xml"])
cluster.start()
cluster.communication_port = 10000
instance.copy_file_to_container(os.path.join(os.path.dirname(__file__), 'test_server.py'), 'test_server.py')
cluster.bucket = 'abc'
instance.exec_in_container(['python', 'test_server.py', str(cluster.communication_port), cluster.bucket], detach=True)
instance.copy_file_to_container(os.path.join(os.path.dirname(__file__), "test_server.py"), "test_server.py")
cluster.bucket = "abc"
instance.exec_in_container(["python", "test_server.py", str(cluster.communication_port), cluster.bucket], detach=True)
cluster.mock_host = instance.ip_address
for i in range(10):
try:
data = get_communication_data(cluster)
cluster.redirecting_to_http_port = data['redirecting_to_http_port']
cluster.preserving_data_port = data['preserving_data_port']
cluster.multipart_preserving_data_port = data['multipart_preserving_data_port']
cluster.redirecting_preserving_data_port = data['redirecting_preserving_data_port']
cluster.redirecting_to_http_port = data["redirecting_to_http_port"]
cluster.preserving_data_port = data["preserving_data_port"]
cluster.multipart_preserving_data_port = data["multipart_preserving_data_port"]
cluster.redirecting_preserving_data_port = data["redirecting_preserving_data_port"]
except:
logging.error(traceback.format_exc())
time.sleep(0.5)
else:
break
else:
assert False, 'Could not initialize mock server'
assert False, "Could not initialize mock server"
yield cluster
......@@ -65,92 +65,97 @@ def started_cluster():
def run_query(instance, query, stdin=None):
logging.info('Running query "{}"...'.format(query))
logging.info("Running query '{}'...".format(query))
result = instance.query(query, stdin=stdin)
logging.info('Query finished')
logging.info("Query finished")
return result
def test_get_with_redirect(started_cluster):
instance = started_cluster.instances['dummy']
format = 'column1 UInt32, column2 UInt32, column3 UInt32'
instance = started_cluster.instances["dummy"]
format = "column1 UInt32, column2 UInt32, column3 UInt32"
put_communication_data(started_cluster, '=== Get with redirect test ===')
put_communication_data(started_cluster, "=== Get with redirect test ===")
query = "select *, column1*column2*column3 from s3('http://{}:{}/', 'CSV', '{}')".format(started_cluster.mock_host, started_cluster.redirecting_to_http_port, format)
stdout = run_query(instance, query)
assert list(map(str.split, stdout.splitlines())) == [
['42', '87', '44', '160776'],
['55', '33', '81', '147015'],
['1', '0', '9', '0'],
["42", "87", "44", "160776"],
["55", "33", "81", "147015"],
["1", "0", "9", "0"],
]
def test_put(started_cluster):
instance = started_cluster.instances['dummy']
format = 'column1 UInt32, column2 UInt32, column3 UInt32'
instance = started_cluster.instances["dummy"]
format = "column1 UInt32, column2 UInt32, column3 UInt32"
logging.info('Phase 3')
put_communication_data(started_cluster, '=== Put test ===')
values = '(1, 2, 3), (3, 2, 1), (78, 43, 45)'
logging.info("Phase 3")
put_communication_data(started_cluster, "=== Put test ===")
values = "(1, 2, 3), (3, 2, 1), (78, 43, 45)"
put_query = "insert into table function s3('http://{}:{}/{}/test.csv', 'CSV', '{}') values {}".format(started_cluster.mock_host, started_cluster.preserving_data_port, started_cluster.bucket, format, values)
run_query(instance, put_query)
data = get_communication_data(started_cluster)
received_data_completed = data['received_data_completed']
received_data = data['received_data']
finalize_data = data['finalize_data']
finalize_data_query = data['finalize_data_query']
assert received_data[-1].decode() == '1,2,3\n3,2,1\n78,43,45\n'
received_data_completed = data["received_data_completed"]
received_data = data["received_data"]
finalize_data = data["finalize_data"]
finalize_data_query = data["finalize_data_query"]
assert received_data[-1].decode() == "1,2,3\n3,2,1\n78,43,45\n"
assert received_data_completed
assert finalize_data == '<CompleteMultipartUpload><Part><PartNumber>1</PartNumber><ETag>hello-etag</ETag></Part></CompleteMultipartUpload>'
assert finalize_data_query == 'uploadId=TEST'
assert finalize_data == "<CompleteMultipartUpload><Part><PartNumber>1</PartNumber><ETag>hello-etag</ETag></Part></CompleteMultipartUpload>"
assert finalize_data_query == "uploadId=TEST"
def test_put_csv(started_cluster):
instance = started_cluster.instances['dummy']
format = 'column1 UInt32, column2 UInt32, column3 UInt32'
instance = started_cluster.instances["dummy"]
format = "column1 UInt32, column2 UInt32, column3 UInt32"
put_communication_data(started_cluster, '=== Put test CSV ===')
put_communication_data(started_cluster, "=== Put test CSV ===")
put_query = "insert into table function s3('http://{}:{}/{}/test.csv', 'CSV', '{}') format CSV".format(started_cluster.mock_host, started_cluster.preserving_data_port, started_cluster.bucket, format)
csv_data = '8,9,16\n11,18,13\n22,14,2\n'
csv_data = "8,9,16\n11,18,13\n22,14,2\n"
run_query(instance, put_query, stdin=csv_data)
data = get_communication_data(started_cluster)
received_data_completed = data['received_data_completed']
received_data = data['received_data']
finalize_data = data['finalize_data']
finalize_data_query = data['finalize_data_query']
received_data_completed = data["received_data_completed"]
received_data = data["received_data"]
finalize_data = data["finalize_data"]
finalize_data_query = data["finalize_data_query"]
assert received_data[-1].decode() == csv_data
assert received_data_completed
assert finalize_data == '<CompleteMultipartUpload><Part><PartNumber>1</PartNumber><ETag>hello-etag</ETag></Part></CompleteMultipartUpload>'
assert finalize_data_query == 'uploadId=TEST'
assert finalize_data == "<CompleteMultipartUpload><Part><PartNumber>1</PartNumber><ETag>hello-etag</ETag></Part></CompleteMultipartUpload>"
assert finalize_data_query == "uploadId=TEST"
def test_put_with_redirect(started_cluster):
instance = started_cluster.instances['dummy']
format = 'column1 UInt32, column2 UInt32, column3 UInt32'
instance = started_cluster.instances["dummy"]
format = "column1 UInt32, column2 UInt32, column3 UInt32"
put_communication_data(started_cluster, '=== Put with redirect test ===')
other_values = '(1, 1, 1), (1, 1, 1), (11, 11, 11)'
put_communication_data(started_cluster, "=== Put with redirect test ===")
other_values = "(1, 1, 1), (1, 1, 1), (11, 11, 11)"
query = "insert into table function s3('http://{}:{}/{}/test.csv', 'CSV', '{}') values {}".format(started_cluster.mock_host, started_cluster.redirecting_preserving_data_port, started_cluster.bucket, format, other_values)
run_query(instance, query)
query = "select *, column1*column2*column3 from s3('http://{}:{}/{}/test.csv', 'CSV', '{}')".format(started_cluster.mock_host, started_cluster.preserving_data_port, started_cluster.bucket, format)
stdout = run_query(instance, query)
assert list(map(str.split, stdout.splitlines())) == [
['1', '1', '1', '1'],
['1', '1', '1', '1'],
['11', '11', '11', '1331'],
["1", "1", "1", "1"],
["1", "1", "1", "1"],
["11", "11", "11", "1331"],
]
data = get_communication_data(started_cluster)
received_data = data['received_data']
assert received_data[-1].decode() == '1,1,1\n1,1,1\n11,11,11\n'
received_data = data["received_data"]
assert received_data[-1].decode() == "1,1,1\n1,1,1\n11,11,11\n"
def test_multipart_put(started_cluster):
instance = started_cluster.instances['dummy']
format = 'column1 UInt32, column2 UInt32, column3 UInt32'
instance = started_cluster.instances["dummy"]
format = "column1 UInt32, column2 UInt32, column3 UInt32"
put_communication_data(started_cluster, '=== Multipart test ===')
put_communication_data(started_cluster, "=== Multipart test ===")
long_data = [[i, i+1, i+2] for i in range(100000)]
long_values = ''.join([ '{},{},{}\n'.format(x,y,z) for x, y, z in long_data ])
long_values = "".join([ "{},{},{}\n".format(x,y,z) for x, y, z in long_data ])
put_query = "insert into table function s3('http://{}:{}/{}/test.csv', 'CSV', '{}') format CSV".format(started_cluster.mock_host, started_cluster.multipart_preserving_data_port, started_cluster.bucket, format)
run_query(instance, put_query, stdin=long_values)
data = get_communication_data(started_cluster)
assert 'multipart_received_data' in data
received_data = data['multipart_received_data']
assert received_data[-1].decode() == ''.join([ '{},{},{}\n'.format(x, y, z) for x, y, z in long_data ])
assert 1 < data['multipart_parts'] < 10000
assert "multipart_received_data" in data
received_data = data["multipart_received_data"]
assert received_data[-1].decode() == "".join([ "{},{},{}\n".format(x, y, z) for x, y, z in long_data ])
assert 1 < data["multipart_parts"] < 10000
......@@ -25,8 +25,8 @@ import xml.etree.ElementTree
logging.getLogger().setLevel(logging.INFO)
file_handler = logging.FileHandler('/var/log/clickhouse-server/test-server.log', 'a', encoding='utf-8')
file_handler.setFormatter(logging.Formatter('%(asctime)s %(message)s'))
file_handler = logging.FileHandler("/var/log/clickhouse-server/test-server.log", "a", encoding="utf-8")
file_handler.setFormatter(logging.Formatter("%(asctime)s %(message)s"))
logging.getLogger().addHandler(file_handler)
logging.getLogger().addHandler(logging.StreamHandler())
......@@ -54,21 +54,21 @@ def GetFreeTCPPortsAndIP(n):
), localhost = GetFreeTCPPortsAndIP(5)
data = {
'redirecting_to_http_port': redirecting_to_http_port,
'preserving_data_port': preserving_data_port,
'multipart_preserving_data_port': multipart_preserving_data_port,
'redirecting_preserving_data_port': redirecting_preserving_data_port,
"redirecting_to_http_port": redirecting_to_http_port,
"preserving_data_port": preserving_data_port,
"multipart_preserving_data_port": multipart_preserving_data_port,
"redirecting_preserving_data_port": redirecting_preserving_data_port,
}
class SimpleHTTPServerHandler(BaseHTTPRequestHandler):
def do_GET(self):
logging.info('GET {}'.format(self.path))
if self.path == '/milovidov/test.csv':
logging.info("GET {}".format(self.path))
if self.path == "/milovidov/test.csv":
self.send_response(200)
self.send_header('Content-type', 'text/plain')
self.send_header("Content-type", "text/plain")
self.end_headers()
self.wfile.write('42,87,44\n55,33,81\n1,0,9\n')
self.wfile.write("42,87,44\n55,33,81\n1,0,9\n")
else:
self.send_response(404)
self.end_headers()
......@@ -78,27 +78,27 @@ class SimpleHTTPServerHandler(BaseHTTPRequestHandler):
class RedirectingToHTTPHandler(BaseHTTPRequestHandler):
def do_GET(self):
self.send_response(307)
self.send_header('Content-type', 'text/xml')
self.send_header('Location', 'http://{}:{}/milovidov/test.csv'.format(localhost, simple_server_port))
self.send_header("Content-type", "text/xml")
self.send_header("Location", "http://{}:{}/milovidov/test.csv".format(localhost, simple_server_port))
self.end_headers()
self.wfile.write(r'''<?xml version="1.0" encoding="UTF-8"?>
self.wfile.write(r"""<?xml version="1.0" encoding="UTF-8"?>
<Error>
<Code>TemporaryRedirect</Code>
<Message>Please re-send this request to the specified temporary endpoint.
Continue to use the original request endpoint for future requests.</Message>
<Endpoint>storage.yandexcloud.net</Endpoint>
</Error>'''.encode())
</Error>""".encode())
self.finish()
class PreservingDataHandler(BaseHTTPRequestHandler):
protocol_version = 'HTTP/1.1'
protocol_version = "HTTP/1.1"
def parse_request(self):
result = BaseHTTPRequestHandler.parse_request(self)
# Adaptation to Python 3.
if sys.version_info.major == 2 and result == True:
expect = self.headers.get('Expect', "")
expect = self.headers.get("Expect", "")
if (expect.lower() == "100-continue" and self.protocol_version >= "HTTP/1.1" and self.request_version >= "HTTP/1.1"):
if not self.handle_expect_100():
return False
......@@ -109,12 +109,12 @@ class PreservingDataHandler(BaseHTTPRequestHandler):
if code in self.responses:
message = self.responses[code][0]
else:
message = ''
if self.request_version != 'HTTP/0.9':
message = ""
if self.request_version != "HTTP/0.9":
self.wfile.write("%s %d %s\r\n" % (self.protocol_version, code, message))
def handle_expect_100(self):
logging.info('Received Expect-100')
logging.info("Received Expect-100")
self.send_response_only(100)
self.end_headers()
return True
......@@ -122,37 +122,37 @@ class PreservingDataHandler(BaseHTTPRequestHandler):
def do_POST(self):
self.send_response(200)
query = urlparse.urlparse(self.path).query
logging.info('PreservingDataHandler POST ?' + query)
if query == 'uploads':
post_data = r'''<?xml version="1.0" encoding="UTF-8"?>
<hi><UploadId>TEST</UploadId></hi>'''.encode()
self.send_header('Content-length', str(len(post_data)))
self.send_header('Content-type', 'text/plain')
logging.info("PreservingDataHandler POST ?" + query)
if query == "uploads":
post_data = r"""<?xml version="1.0" encoding="UTF-8"?>
<hi><UploadId>TEST</UploadId></hi>""".encode()
self.send_header("Content-length", str(len(post_data)))
self.send_header("Content-type", "text/plain")
self.end_headers()
self.wfile.write(post_data)
else:
post_data = self.rfile.read(int(self.headers.get('Content-Length')))
self.send_header('Content-type', 'text/plain')
post_data = self.rfile.read(int(self.headers.get("Content-Length")))
self.send_header("Content-type", "text/plain")
self.end_headers()
data['received_data_completed'] = True
data['finalize_data'] = post_data
data['finalize_data_query'] = query
data["received_data_completed"] = True
data["finalize_data"] = post_data
data["finalize_data_query"] = query
self.finish()
def do_PUT(self):
self.send_response(200)
self.send_header('Content-type', 'text/plain')
self.send_header('ETag', 'hello-etag')
self.send_header("Content-type", "text/plain")
self.send_header("ETag", "hello-etag")
self.end_headers()
query = urlparse.urlparse(self.path).query
path = urlparse.urlparse(self.path).path
logging.info('Content-Length = ' + self.headers.get('Content-Length'))
logging.info('PUT ' + query)
assert self.headers.get('Content-Length')
assert self.headers['Expect'] == '100-continue'
logging.info("Content-Length = " + self.headers.get("Content-Length"))
logging.info("PUT " + query)
assert self.headers.get("Content-Length")
assert self.headers["Expect"] == "100-continue"
put_data = self.rfile.read()
data.setdefault('received_data', []).append(put_data)
logging.info('PUT to {}'.format(path))
data.setdefault("received_data", []).append(put_data)
logging.info("PUT to {}".format(path))
self.server.storage[path] = put_data
self.finish()
......@@ -160,8 +160,8 @@ class PreservingDataHandler(BaseHTTPRequestHandler):
path = urlparse.urlparse(self.path).path
if path in self.server.storage:
self.send_response(200)
self.send_header('Content-type', 'text/plain')
self.send_header('Content-length', str(len(self.server.storage[path])))
self.send_header("Content-type", "text/plain")
self.send_header("Content-length", str(len(self.server.storage[path])))
self.end_headers()
self.wfile.write(self.server.storage[path])
else:
......@@ -171,13 +171,13 @@ class PreservingDataHandler(BaseHTTPRequestHandler):
class MultipartPreservingDataHandler(BaseHTTPRequestHandler):
protocol_version = 'HTTP/1.1'
protocol_version = "HTTP/1.1"
def parse_request(self):
result = BaseHTTPRequestHandler.parse_request(self)
# Adaptation to Python 3.
if sys.version_info.major == 2 and result == True:
expect = self.headers.get('Expect', "")
expect = self.headers.get("Expect", "")
if (expect.lower() == "100-continue" and self.protocol_version >= "HTTP/1.1" and self.request_version >= "HTTP/1.1"):
if not self.handle_expect_100():
return False
......@@ -188,78 +188,78 @@ class MultipartPreservingDataHandler(BaseHTTPRequestHandler):
if code in self.responses:
message = self.responses[code][0]
else:
message = ''
if self.request_version != 'HTTP/0.9':
message = ""
if self.request_version != "HTTP/0.9":
self.wfile.write("%s %d %s\r\n" % (self.protocol_version, code, message))
def handle_expect_100(self):
logging.info('Received Expect-100')
logging.info("Received Expect-100")
self.send_response_only(100)
self.end_headers()
return True
def do_POST(self):
query = urlparse.urlparse(self.path).query
logging.info('MultipartPreservingDataHandler POST ?' + query)
if query == 'uploads':
logging.info("MultipartPreservingDataHandler POST ?" + query)
if query == "uploads":
self.send_response(200)
post_data = r'''<?xml version="1.0" encoding="UTF-8"?>
<hi><UploadId>TEST</UploadId></hi>'''.encode()
self.send_header('Content-length', str(len(post_data)))
self.send_header('Content-type', 'text/plain')
post_data = r"""<?xml version="1.0" encoding="UTF-8"?>
<hi><UploadId>TEST</UploadId></hi>""".encode()
self.send_header("Content-length", str(len(post_data)))
self.send_header("Content-type", "text/plain")
self.end_headers()
self.wfile.write(post_data)
else:
try:
assert query == 'uploadId=TEST'
logging.info('Content-Length = ' + self.headers.get('Content-Length'))
post_data = self.rfile.read(int(self.headers.get('Content-Length')))
assert query == "uploadId=TEST"
logging.info("Content-Length = " + self.headers.get("Content-Length"))
post_data = self.rfile.read(int(self.headers.get("Content-Length")))
root = xml.etree.ElementTree.fromstring(post_data)
assert root.tag == 'CompleteMultipartUpload'
assert root.tag == "CompleteMultipartUpload"
assert len(root) > 1
content = ''
content = ""
for i, part in enumerate(root):
assert part.tag == 'Part'
assert part.tag == "Part"
assert len(part) == 2
assert part[0].tag == 'PartNumber'
assert part[1].tag == 'ETag'
assert part[0].tag == "PartNumber"
assert part[1].tag == "ETag"
assert int(part[0].text) == i + 1
content += self.server.storage['@'+part[1].text]
data.setdefault('multipart_received_data', []).append(content)
data['multipart_parts'] = len(root)
content += self.server.storage["@"+part[1].text]
data.setdefault("multipart_received_data", []).append(content)
data["multipart_parts"] = len(root)
self.send_response(200)
self.send_header('Content-type', 'text/plain')
self.send_header("Content-type", "text/plain")
self.end_headers()
logging.info('Sending 200')
logging.info("Sending 200")
except:
logging.error('Sending 500')
logging.error("Sending 500")
self.send_response(500)
self.finish()
def do_PUT(self):
uid = uuid.uuid4()
self.send_response(200)
self.send_header('Content-type', 'text/plain')
self.send_header('ETag', str(uid))
self.send_header("Content-type", "text/plain")
self.send_header("ETag", str(uid))
self.end_headers()
query = urlparse.urlparse(self.path).query
path = urlparse.urlparse(self.path).path
logging.info('Content-Length = ' + self.headers.get('Content-Length'))
logging.info('PUT ' + query)
assert self.headers.get('Content-Length')
assert self.headers['Expect'] == '100-continue'
logging.info("Content-Length = " + self.headers.get("Content-Length"))
logging.info("PUT " + query)
assert self.headers.get("Content-Length")
assert self.headers["Expect"] == "100-continue"
put_data = self.rfile.read()
data.setdefault('received_data', []).append(put_data)
logging.info('PUT to {}'.format(path))
self.server.storage['@'+str(uid)] = put_data
data.setdefault("received_data", []).append(put_data)
logging.info("PUT to {}".format(path))
self.server.storage["@"+str(uid)] = put_data
self.finish()
def do_GET(self):
path = urlparse.urlparse(self.path).path
if path in self.server.storage:
self.send_response(200)
self.send_header('Content-type', 'text/plain')
self.send_header('Content-length', str(len(self.server.storage[path])))
self.send_header("Content-type", "text/plain")
self.send_header("Content-length", str(len(self.server.storage[path])))
self.end_headers()
self.wfile.write(self.server.storage[path])
else:
......@@ -269,13 +269,13 @@ class MultipartPreservingDataHandler(BaseHTTPRequestHandler):
class RedirectingPreservingDataHandler(BaseHTTPRequestHandler):
protocol_version = 'HTTP/1.1'
protocol_version = "HTTP/1.1"
def parse_request(self):
result = BaseHTTPRequestHandler.parse_request(self)
# Adaptation to Python 3.
if sys.version_info.major == 2 and result == True:
expect = self.headers.get('Expect', "")
expect = self.headers.get("Expect", "")
if (expect.lower() == "100-continue" and self.protocol_version >= "HTTP/1.1" and self.request_version >= "HTTP/1.1"):
if not self.handle_expect_100():
return False
......@@ -286,46 +286,46 @@ class RedirectingPreservingDataHandler(BaseHTTPRequestHandler):
if code in self.responses:
message = self.responses[code][0]
else:
message = ''
if self.request_version != 'HTTP/0.9':
message = ""
if self.request_version != "HTTP/0.9":
self.wfile.write("%s %d %s\r\n" % (self.protocol_version, code, message))
def handle_expect_100(self):
logging.info('Received Expect-100')
logging.info("Received Expect-100")
return True
def do_POST(self):
query = urlparse.urlparse(self.path).query
if query:
query = '?{}'.format(query)
query = "?{}".format(query)
self.send_response(307)
self.send_header('Content-type', 'text/xml')
self.send_header('Location', 'http://{host}:{port}/{bucket}/test.csv{query}'.format(host=localhost, port=preserving_data_port, bucket=bucket, query=query))
self.send_header("Content-type", "text/xml")
self.send_header("Location", "http://{host}:{port}/{bucket}/test.csv{query}".format(host=localhost, port=preserving_data_port, bucket=bucket, query=query))
self.end_headers()
self.wfile.write(r'''<?xml version="1.0" encoding="UTF-8"?>
self.wfile.write(r"""<?xml version="1.0" encoding="UTF-8"?>
<Error>
<Code>TemporaryRedirect</Code>
<Message>Please re-send this request to the specified temporary endpoint.
Continue to use the original request endpoint for future requests.</Message>
<Endpoint>{host}:{port}</Endpoint>
</Error>'''.format(host=localhost, port=preserving_data_port).encode())
</Error>""".format(host=localhost, port=preserving_data_port).encode())
self.finish()
def do_PUT(self):
query = urlparse.urlparse(self.path).query
if query:
query = '?{}'.format(query)
query = "?{}".format(query)
self.send_response(307)
self.send_header('Content-type', 'text/xml')
self.send_header('Location', 'http://{host}:{port}/{bucket}/test.csv{query}'.format(host=localhost, port=preserving_data_port, bucket=bucket, query=query))
self.send_header("Content-type", "text/xml")
self.send_header("Location", "http://{host}:{port}/{bucket}/test.csv{query}".format(host=localhost, port=preserving_data_port, bucket=bucket, query=query))
self.end_headers()
self.wfile.write(r'''<?xml version="1.0" encoding="UTF-8"?>
self.wfile.write(r"""<?xml version="1.0" encoding="UTF-8"?>
<Error>
<Code>TemporaryRedirect</Code>
<Message>Please re-send this request to the specified temporary endpoint.
Continue to use the original request endpoint for future requests.</Message>
<Endpoint>{host}:{port}</Endpoint>
</Error>'''.format(host=localhost, port=preserving_data_port).encode())
</Error>""".format(host=localhost, port=preserving_data_port).encode())
self.finish()
......@@ -357,8 +357,8 @@ jobs = [ threading.Thread(target=server.serve_forever) for server in servers ]
time.sleep(60) # Timeout
logging.info('Shutting down')
logging.info("Shutting down")
[ server.shutdown() for server in servers ]
logging.info('Joining threads')
logging.info("Joining threads")
[ job.join() for job in jobs ]
logging.info('Done')
logging.info("Done")
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册