httpserver_test.py 49.2 KB
Newer Older
1
from tornado import gen, netutil
B
Ben Darnell 已提交
2 3 4 5 6 7 8 9
from tornado.escape import (
    json_decode,
    json_encode,
    utf8,
    _unicode,
    recursive_unicode,
    native_str,
)
10
from tornado.http1connection import HTTP1Connection
11
from tornado.httpclient import HTTPError
B
Ben Darnell 已提交
12
from tornado.httpserver import HTTPServer
B
Ben Darnell 已提交
13 14 15 16 17
from tornado.httputil import (
    HTTPHeaders,
    HTTPMessageDelegate,
    HTTPServerConnectionDelegate,
    ResponseStartLine,
18
)
B
Ben Darnell 已提交
19
from tornado.iostream import IOStream
20
from tornado.locks import Event
21
from tornado.log import gen_log, app_log
B
Ben Darnell 已提交
22
from tornado.netutil import ssl_options_to_context
23
from tornado.simple_httpclient import SimpleAsyncHTTPClient
B
Ben Darnell 已提交
24 25 26 27 28 29
from tornado.testing import (
    AsyncHTTPTestCase,
    AsyncHTTPSTestCase,
    AsyncTestCase,
    ExpectLog,
    gen_test,
30
)
31
from tornado.test.util import skipOnTravis
32
from tornado.web import Application, RequestHandler, stream_request_body
33

34
from contextlib import closing
35
import datetime
36
import gzip
37
import logging
38
import os
B
Ben Darnell 已提交
39 40
import shutil
import socket
41
import ssl
42
import sys
B
Ben Darnell 已提交
43
import tempfile
44
import textwrap
45
import unittest
46
import urllib.parse
M
Mikhail Korobov 已提交
47
from io import BytesIO
48

49
import typing
B
Ben Darnell 已提交
50

51 52 53
if typing.TYPE_CHECKING:
    from typing import Dict, List  # noqa: F401

54

55 56
async def read_stream_body(stream):
    """Read a single HTTP response from ``stream``.

    Returns a ``(start_line, headers, body)`` tuple. ``body`` is every chunk
    delivered to the message delegate, joined into one ``bytes`` object. When
    the message finishes, the ``HTTP1Connection`` is detached (presumably so
    the caller keeps control of ``stream``).
    """
    # Second positional argument True: this side of the connection is the client.
    connection = HTTP1Connection(stream, True)
    pieces = []

    class _BodyCollector(HTTPMessageDelegate):
        def headers_received(self, start_line, headers):
            self.start_line = start_line
            self.headers = headers

        def data_received(self, chunk):
            pieces.append(chunk)

        def finish(self):
            connection.detach()  # type: ignore

    collector = _BodyCollector()
    await connection.read_response(collector)
    return collector.start_line, collector.headers, b"".join(pieces)
75 76


77
class HandlerBaseTestCase(AsyncHTTPTestCase):
    """Test case that serves a single ``Handler`` class at ``/``.

    Subclasses assign a ``RequestHandler`` subclass to ``Handler``.
    """

    Handler = None

    def get_app(self):
        handler_class = type(self).Handler
        return Application([("/", handler_class)])

    def fetch_json(self, *args, **kwargs):
        """Fetch via ``self.fetch``, re-raise any HTTP error, return the decoded JSON."""
        result = self.fetch(*args, **kwargs)
        result.rethrow()
        return json_decode(result.body)

88

89
class HelloWorldRequestHandler(RequestHandler):
    """Reply with a fixed greeting, or report the size of a POST body.

    GET fails unless the request arrived over the protocol passed to
    ``initialize`` (``"http"`` unless overridden).
    """

    def initialize(self, protocol="http"):
        self.expected_protocol = protocol

    def get(self):
        protocol_ok = self.request.protocol == self.expected_protocol
        if not protocol_ok:
            raise Exception("unexpected protocol")
        self.finish("Hello world")

    def post(self):
        body_length = len(self.request.body)
        self.finish("Got %d bytes in POST" % body_length)

101

102 103 104 105 106 107
# OpenSSL releases before 1.0 make SSLv23 clients always send SSLv2
# ClientHello messages, which SSLv3 and TLSv1 servers reject.
# ``ssl.OPENSSL_VERSION_INFO`` is looked up defensively: if the attribute is
# missing, the version is treated as (0, 0), i.e. too old.
skipIfOldSSL = unittest.skipIf(
    (1, 0) > getattr(ssl, "OPENSSL_VERSION_INFO", (0, 0)),
    "old version of ssl module and/or openssl",
)
111 112


113
class BaseSSLTest(AsyncHTTPSTestCase):
    """HTTPS test case serving ``HelloWorldRequestHandler`` at ``/`` (expects https)."""

    def get_app(self):
        routes = [("/", HelloWorldRequestHandler, {"protocol": "https"})]
        return Application(routes)
116

117

118
class SSLTestMixin(object):
    """HTTPS tests shared by the per-protocol test cases.

    Concrete classes combine this mixin with ``BaseSSLTest`` and implement
    ``get_ssl_version``.

    NOTE(review): those classes list ``BaseSSLTest`` before this mixin, and
    ``AsyncHTTPSTestCase`` defines its own ``get_ssl_options`` (SSLContextTest
    calls it directly). That method therefore precedes ``get_ssl_options``
    below in the MRO, and the ``ssl_version`` built here appears never to be
    used. Confirm before relying on these cases for per-protocol coverage.
    Swapping the base order would change which protocols are exercised.
    """

    def get_ssl_options(self):
        protocol_version = self.get_ssl_version()
        return dict(
            ssl_version=protocol_version,
            **AsyncHTTPSTestCase.default_ssl_options(),
        )

    def get_ssl_version(self):
        raise NotImplementedError()

    def test_ssl(self: typing.Any):
        self.assertEqual(self.fetch("/").body, b"Hello world")

    def test_large_post(self: typing.Any):
        payload = "A" * 5000
        reply = self.fetch("/", method="POST", body=payload)
        self.assertEqual(reply.body, b"Got 5000 bytes in POST")

    def _plaintext_url(self: typing.Any):
        # Same server URL but with an http:// scheme, so the client talks
        # plaintext to the TLS listener.
        return self.get_url("/").replace("https:", "http:")

    def test_non_ssl_request(self: typing.Any):
        # A plaintext connection to the TLS port must make the server close
        # the connection, rather than wait for a timeout or otherwise
        # misbehave.
        required_log = ExpectLog(gen_log, "(SSL Error|uncaught exception)")
        optional_log = ExpectLog(gen_log, "Uncaught exception", required=False)
        with required_log, optional_log:
            with self.assertRaises((IOError, HTTPError)):  # type: ignore
                self.fetch(
                    self._plaintext_url(),
                    request_timeout=3600,
                    connect_timeout=3600,
                    raise_error=True,
                )

    def test_error_logging(self: typing.Any):
        # SSL errors must be logged without a stack trace.
        with ExpectLog(gen_log, "SSL Error") as log_watch:
            with self.assertRaises((IOError, HTTPError)):  # type: ignore
                self.fetch(self._plaintext_url(), raise_error=True)
        self.assertFalse(log_watch.logged_stack)

B
Ben Darnell 已提交
159

160 161 162 163
# Python's SSL implementation differs significantly between versions.
# For example, SSLv3 and TLSv1 raise an exception if you try to read from
# the socket before the handshake is complete, whereas the default of
# SSLv23 allows it.


class SSLv23Test(BaseSSLTest, SSLTestMixin):
    """Shared SSL tests; ``get_ssl_version`` answers ``ssl.PROTOCOL_SSLv23``.

    NOTE(review): ``BaseSSLTest`` precedes ``SSLTestMixin`` in the bases, so
    ``AsyncHTTPSTestCase.get_ssl_options`` appears to shadow the mixin's
    ``get_ssl_options`` (the only caller of ``get_ssl_version``). Confirm
    whether this protocol is actually applied.
    """

    def get_ssl_version(self):
        selected = ssl.PROTOCOL_SSLv23
        return selected


171
@skipIfOldSSL
@unittest.skipIf(
    not hasattr(ssl, "PROTOCOL_SSLv3"),
    "ssl module was built without SSLv3 support",
)
class SSLv3Test(BaseSSLTest, SSLTestMixin):
    """Shared SSL tests; ``get_ssl_version`` answers ``ssl.PROTOCOL_SSLv3``.

    ``ssl.PROTOCOL_SSLv3`` only exists when OpenSSL was built with SSLv3
    enabled. On other builds the case is skipped, rather than failing with
    AttributeError once ``get_ssl_version`` is called.

    NOTE(review): ``BaseSSLTest`` precedes ``SSLTestMixin`` in the bases, so
    ``AsyncHTTPSTestCase.get_ssl_options`` appears to shadow the mixin's
    ``get_ssl_options``. ``get_ssl_version`` may therefore never be called
    today; confirm.
    """

    def get_ssl_version(self):
        return ssl.PROTOCOL_SSLv3

176

177
@skipIfOldSSL
class TLSv1Test(BaseSSLTest, SSLTestMixin):
    """Shared SSL tests; ``get_ssl_version`` answers ``ssl.PROTOCOL_TLSv1``.

    NOTE(review): ``BaseSSLTest`` precedes ``SSLTestMixin`` in the bases, so
    ``AsyncHTTPSTestCase.get_ssl_options`` appears to shadow the mixin's
    ``get_ssl_options``. Confirm whether TLSv1 is actually applied.
    """

    def get_ssl_version(self):
        selected = ssl.PROTOCOL_TLSv1
        return selected
181

A
Alek Storm 已提交
182

183 184
class SSLContextTest(BaseSSLTest, SSLTestMixin):
    """Shared SSL tests, with ``ssl_options`` given as a ready-made ``SSLContext``."""

    def get_ssl_options(self):
        option_dict = AsyncHTTPSTestCase.get_ssl_options(self)
        server_context = ssl_options_to_context(option_dict, server_side=True)
        assert isinstance(server_context, ssl.SSLContext)
        return server_context


192 193 194
class BadSSLOptionsTest(unittest.TestCase):
    """HTTPServer validates ``ssl_options`` when it is constructed."""

    def test_missing_arguments(self):
        app = Application()
        with self.assertRaises(KeyError):
            HTTPServer(app, ssl_options={"keyfile": "/__missing__.crt"})

    def test_missing_key(self):
        """Missing certificate or key files raise as soon as HTTPServer is built."""
        app = Application()
        here = os.path.dirname(__file__)
        cert_path = os.path.join(here, "test.crt")
        key_path = os.path.join(here, "test.key")

        with self.assertRaises((ValueError, IOError)):
            HTTPServer(app, ssl_options={"certfile": "/__mising__.crt"})
        with self.assertRaises((ValueError, IOError)):
            HTTPServer(
                app,
                ssl_options={"certfile": cert_path, "keyfile": "/__missing__.key"},
            )

        # Succeeds because both the certificate and the key file exist.
        HTTPServer(app, ssl_options={"certfile": cert_path, "keyfile": key_path})
231 232


233 234
class MultipartTestHandler(RequestHandler):
    """Echo back a header, a form argument and the first uploaded file as JSON."""

    def post(self):
        # Evaluated in the same order as the response fields, so a missing
        # header/argument is reported before a missing file.
        header = self.request.headers["X-Header-Encoding-Test"]
        argument = self.get_argument("argument")
        upload = self.request.files["files"][0]
        self.finish(
            dict(
                header=header,
                argument=argument,
                filename=upload.filename,
                filebody=_unicode(upload["body"]),
            )
        )
243

244

245
# This test is also called from wsgi_test
class HTTPConnectionTest(AsyncHTTPTestCase):
    """Exercise the HTTP/1.x connection layer with hand-written requests."""

    def get_handlers(self):
        return [
            ("/multipart", MultipartTestHandler),
            ("/hello", HelloWorldRequestHandler),
        ]

    def get_app(self):
        return Application(self.get_handlers())

    def raw_fetch(self, headers, body, newline=b"\r\n"):
        """Send a hand-built request and return the response body.

        ``headers`` is a list of byte strings, starting with the request line.
        A ``Content-Length`` header for ``body`` is appended automatically.
        ``newline`` separates the header lines and ends the header block.
        """
        with closing(IOStream(socket.socket())) as stream:
            self.io_loop.run_sync(
                lambda: stream.connect(("127.0.0.1", self.get_http_port()))
            )
            stream.write(
                newline.join(headers + [utf8("Content-Length: %d" % len(body))])
                + newline
                + newline
                + body
            )
            # Only the body is needed; don't shadow the ``headers``/``body`` params.
            _, _, response_body = self.io_loop.run_sync(
                lambda: read_stream_body(stream)
            )
            return response_body

    def test_multipart_form(self):
        # Encodings here are tricky:  Headers are latin1, bodies can be
        # anything (we use utf8 by default).
        response = self.raw_fetch(
            [
                b"POST /multipart HTTP/1.0",
                b"Content-Type: multipart/form-data; boundary=1234567890",
                b"X-Header-encoding-test: \xe9",
            ],
            b"\r\n".join(
                [
                    b"Content-Disposition: form-data; name=argument",
                    b"",
                    "\u00e1".encode("utf-8"),
                    b"--1234567890",
                    'Content-Disposition: form-data; name="files"; filename="\u00f3"'.encode(
                        "utf8"
                    ),
                    b"",
                    "\u00fa".encode("utf-8"),
                    b"--1234567890--",
                    b"",
                ]
            ),
        )
        data = json_decode(response)
        self.assertEqual("\u00e9", data["header"])
        self.assertEqual("\u00e1", data["argument"])
        self.assertEqual("\u00f3", data["filename"])
        self.assertEqual("\u00fa", data["filebody"])

    def test_newlines(self):
        # We support both CRLF and bare LF as line separators.
        for newline in (b"\r\n", b"\n"):
            response = self.raw_fetch([b"GET /hello HTTP/1.0"], b"", newline=newline)
            self.assertEqual(response, b"Hello world")

    @gen_test
    def test_100_continue(self):
        # Run through a 100-continue interaction by hand: given
        # "Expect: 100-continue", the server sends a 100 response right after
        # the headers, and the real response only after the body arrives.
        # closing() releases the socket even if an assertion or the gen_test
        # timeout interrupts the exchange.
        with closing(IOStream(socket.socket())) as stream:
            yield stream.connect(("127.0.0.1", self.get_http_port()))
            yield stream.write(
                b"\r\n".join(
                    [
                        b"POST /hello HTTP/1.1",
                        b"Content-Length: 1024",
                        b"Expect: 100-continue",
                        b"Connection: close",
                        b"\r\n",
                    ]
                )
            )
            data = yield stream.read_until(b"\r\n\r\n")
            self.assertTrue(data.startswith(b"HTTP/1.1 100 "), data)
            stream.write(b"a" * 1024)
            first_line = yield stream.read_until(b"\r\n")
            self.assertTrue(first_line.startswith(b"HTTP/1.1 200"), first_line)
            header_data = yield stream.read_until(b"\r\n\r\n")
            headers = HTTPHeaders.parse(native_str(header_data.decode("latin1")))
            body = yield stream.read_bytes(int(headers["Content-Length"]))
            self.assertEqual(body, b"Got 1024 bytes in POST")
337

338

339 340
class EchoHandler(RequestHandler):
    """Write the request's arguments back, decoded to text, for GET and POST."""

    def _echo_arguments(self):
        self.write(recursive_unicode(self.request.arguments))

    def get(self):
        self._echo_arguments()

    def post(self):
        self._echo_arguments()
345

B
Ben Darnell 已提交
346

347 348
class TypeCheckHandler(RequestHandler):
    """Report request attributes whose runtime type differs from the expected one.

    ``prepare`` records mismatches in ``self.errors`` (label -> message).
    ``get`` and ``post`` write that dict back, so an empty object means every
    check passed. Types are compared exactly (``==``), not with ``isinstance``.
    """

    def prepare(self):
        self.errors = {}  # type: Dict[str, str]
        req = self.request

        for attr_name in (
            "method",
            "uri",
            "version",
            "remote_ip",
            "protocol",
            "host",
            "path",
            "query",
        ):
            self.check_type(attr_name, getattr(req, attr_name), str)

        self.check_type("header_key", list(req.headers.keys())[0], str)
        self.check_type("header_value", list(req.headers.values())[0], str)

        self.check_type("cookie_key", list(req.cookies.keys())[0], str)
        first_cookie = list(req.cookies.values())[0]
        self.check_type("cookie_value", first_cookie.value, str)

        self.check_type("arg_key", list(req.arguments.keys())[0], str)
        first_arg_values = list(req.arguments.values())[0]
        self.check_type("arg_value", first_arg_values[0], bytes)

    def post(self):
        self.check_type("body", self.request.body, bytes)
        self.write(self.errors)

    def get(self):
        self.write(self.errors)

    def check_type(self, name, obj, expected_type):
        actual_type = type(obj)
        if actual_type == expected_type:
            return
        self.errors[name] = "expected %s, got %s" % (expected_type, actual_type)
386

387

388 389 390 391 392 393 394 395 396 397 398 399 400
class PostEchoHandler(RequestHandler):
    """Echo the ``data`` form argument back as ``{"echo": ...}``."""

    def post(self, *path_args):
        echoed = self.get_argument("data")
        self.write({"echo": echoed})


class PostEchoGBKHandler(PostEchoHandler):
    """``PostEchoHandler`` variant that decodes request arguments as GBK."""

    def decode_argument(self, value, name=None):
        """Decode ``value`` as GBK, answering 400 for bytes that are not valid GBK.

        The module-level ``HTTPError`` in this file is imported from
        ``tornado.httpclient``; the server-side ``tornado.web.HTTPError`` is
        needed for the handler to turn the failure into a 400 response.
        """
        from tornado.web import HTTPError as WebHTTPError

        try:
            return value.decode("gbk")
        except UnicodeDecodeError as exc:
            # Pass the value as a format argument instead of pre-formatting it,
            # so a literal "%" in the bytes cannot break log-message formatting.
            raise WebHTTPError(400, "invalid gbk bytes: %r", value) from exc


401
class HTTPServerTest(AsyncHTTPTestCase):
    """Argument parsing, request attribute types and path handling."""

    def get_app(self):
        routes = [
            ("/echo", EchoHandler),
            ("/typecheck", TypeCheckHandler),
            ("//doubleslash", EchoHandler),
            ("/post_utf8", PostEchoHandler),
            ("/post_gbk", PostEchoGBKHandler),
        ]
        return Application(routes)

    def _fetch_decoded(self, path, **kwargs):
        # Fetch ``path`` and return the JSON-decoded response body.
        return json_decode(self.fetch(path, **kwargs).body)

    def test_query_string_encoding(self):
        self.assertEqual(
            self._fetch_decoded("/echo?foo=%C3%A9"), {"foo": ["\u00e9"]}
        )

    def test_empty_query_string(self):
        self.assertEqual(self._fetch_decoded("/echo?foo=&foo="), {"foo": ["", ""]})

    def test_empty_post_parameters(self):
        self.assertEqual(
            self._fetch_decoded("/echo", method="POST", body="foo=&bar="),
            {"foo": [""], "bar": [""]},
        )

    def test_types(self):
        cookie_header = {"Cookie": "foo=bar"}
        self.assertEqual(
            self._fetch_decoded("/typecheck?foo=bar", headers=cookie_header), {}
        )
        self.assertEqual(
            self._fetch_decoded(
                "/typecheck", method="POST", body="foo=bar", headers=cookie_header
            ),
            {},
        )

    def test_double_slash(self):
        # urlparse.urlsplit (which tornado.httpserver once used incorrectly)
        # would treat paths beginning with "//" as protocol-relative URLs.
        response = self.fetch("//doubleslash")
        self.assertEqual(200, response.code)
        self.assertEqual(json_decode(response.body), {})

    def test_post_encodings(self):
        form_headers = {"Content-Type": "application/x-www-form-urlencoded"}
        text = "chinese: \u5f20\u4e09"
        for enc in ("utf8", "gbk"):
            for quote in (True, False):
                with self.subTest(enc=enc, quote=quote):
                    raw = text.encode(enc)
                    payload = (
                        urllib.parse.quote(raw).encode("ascii") if quote else raw
                    )
                    echoed = self._fetch_decoded(
                        "/post_" + enc,
                        method="POST",
                        headers=form_headers,
                        body=(b"data=" + payload),
                    )
                    self.assertEqual(echoed, {"echo": text})
464

465 466 467

class HTTPServerRawTest(AsyncHTTPTestCase):
    """Feed hand-crafted bytes to the server over a raw IOStream.

    ``setUp`` connects ``self.stream`` to the test server; ``tearDown``
    closes it.
    """

    def get_app(self):
        return Application([("/echo", EchoHandler)])

    def setUp(self):
        super().setUp()
        self.stream = IOStream(socket.socket())
        self.io_loop.run_sync(
            lambda: self.stream.connect(("127.0.0.1", self.get_http_port()))
        )

    def tearDown(self):
        self.stream.close()
        super().tearDown()

    @staticmethod
    def _crlf_request(*lines):
        # Join request lines with CRLF, the HTTP/1.x line terminator.
        return b"\r\n".join(lines)

    def _read_raw_response(self):
        # Parse one response from self.stream: (start_line, headers, body).
        return self.io_loop.run_sync(lambda: read_stream_body(self.stream))

    def _spin_loop(self, seconds):
        # Let the IOLoop run for ``seconds`` so the server can process input.
        self.io_loop.add_timeout(datetime.timedelta(seconds=seconds), self.stop)
        self.wait()

    def test_empty_request(self):
        self.stream.close()
        self._spin_loop(0.001)

    def test_malformed_first_line_response(self):
        with ExpectLog(gen_log, ".*Malformed HTTP request line", level=logging.INFO):
            self.stream.write(b"asdf\r\n\r\n")
            start_line, _, _ = self._read_raw_response()
            self.assertEqual("HTTP/1.1", start_line.version)
            self.assertEqual(400, start_line.code)
            self.assertEqual("Bad Request", start_line.reason)

    def test_malformed_first_line_log(self):
        with ExpectLog(gen_log, ".*Malformed HTTP request line", level=logging.INFO):
            self.stream.write(b"asdf\r\n\r\n")
            # TODO: need an async version of ExpectLog so we don't need
            # hard-coded timeouts here.
            self._spin_loop(0.05)

    def test_malformed_headers(self):
        with ExpectLog(
            gen_log,
            ".*Malformed HTTP message.*no colon in header line",
            level=logging.INFO,
        ):
            self.stream.write(b"GET / HTTP/1.0\r\nasdf\r\n\r\n")
            self._spin_loop(0.05)

    def test_chunked_request_body(self):
        # AsyncHTTPClient cannot generate chunked requests (they are not
        # widely supported), but HTTPServer must still read them.
        self.stream.write(
            self._crlf_request(
                b"POST /echo HTTP/1.1",
                b"Transfer-Encoding: chunked",
                b"Content-Type: application/x-www-form-urlencoded",
                b"",
                b"4",
                b"foo=",
                b"3",
                b"bar",
                b"0",
                b"",
                b"",
            )
        )
        _, _, body = self._read_raw_response()
        self.assertEqual(json_decode(body), {"foo": ["bar"]})

    def test_chunked_request_uppercase(self):
        # The Transfer-Encoding value is case-insensitive (RFC 2616
        # section 3.6), so "Chunked" must be accepted as well.
        self.stream.write(
            self._crlf_request(
                b"POST /echo HTTP/1.1",
                b"Transfer-Encoding: Chunked",
                b"Content-Type: application/x-www-form-urlencoded",
                b"",
                b"4",
                b"foo=",
                b"3",
                b"bar",
                b"0",
                b"",
                b"",
            )
        )
        _, _, body = self._read_raw_response()
        self.assertEqual(json_decode(body), {"foo": ["bar"]})

    def test_chunked_request_body_invalid_size(self):
        # Chunk sizes may only contain hex digits. Python's int() also accepts
        # underscores, so make sure the server rejects them.
        self.stream.write(
            self._crlf_request(
                b"POST /echo HTTP/1.1",
                b"Transfer-Encoding: chunked",
                b"",
                b"1_a",
                b"1234567890abcdef1234567890",
                b"0",
                b"",
                b"",
            )
        )
        with ExpectLog(gen_log, ".*invalid chunk size", level=logging.INFO):
            start_line, _, _ = self._read_raw_response()
        self.assertEqual(400, start_line.code)

    @gen_test
    def test_invalid_content_length(self):
        # HTTP allows only decimal digits in Content-Length. Values that
        # Python's int() would still accept (a leading plus sign, internal
        # underscores) must be rejected too.
        bad_values = {
            "alphabetic": "foo",
            "leading plus": "+10",
            "internal underscore": "1_0",
        }
        for label, value in bad_values.items():
            request = utf8(
                textwrap.dedent(
                    f"""\
                    POST /echo HTTP/1.1
                    Content-Length: {value}
                    Connection: close

                    1234567890
                    """
                ).replace("\n", "\r\n")
            )
            with self.subTest(name=label):
                with closing(IOStream(socket.socket())) as stream:
                    with ExpectLog(
                        gen_log,
                        ".*Only integer Content-Length is allowed",
                        level=logging.INFO,
                    ):
                        yield stream.connect(("127.0.0.1", self.get_http_port()))
                        stream.write(request)
                        yield stream.read_until_close()
D
daftshady 已提交
616

617

618 619 620
class XHeaderTest(HandlerBaseTestCase):
    """With ``xheaders=True``, remote IP and scheme come from proxy headers."""

    class Handler(RequestHandler):
        def get(self):
            req = self.request
            self.set_header("request-version", req.version)
            self.write({"remote_ip": req.remote_ip, "remote_protocol": req.protocol})

    def get_httpserver_options(self):
        return {"xheaders": True, "trusted_downstream": ["5.5.5.5"]}

    def _reported(self, field, headers=None):
        # Fetch "/" (sending ``headers`` only when given); return one JSON field.
        extra = {} if headers is None else {"headers": headers}
        return self.fetch_json("/", **extra)[field]

    def test_ip_headers(self):
        self.assertEqual(self._reported("remote_ip"), "127.0.0.1")

        ipv6 = "2620:0:1cfe:face:b00c::3"
        cases = [
            ({"X-Real-IP": "4.4.4.4"}, "4.4.4.4"),
            ({"X-Forwarded-For": "127.0.0.1, 4.4.4.4"}, "4.4.4.4"),
            ({"X-Real-IP": ipv6}, ipv6),
            ({"X-Forwarded-For": "::1, " + ipv6}, ipv6),
            # Header values that are not plain IP addresses are ignored and
            # the peer address is reported instead.
            ({"X-Real-IP": "4.4.4.4<script>"}, "127.0.0.1"),
            ({"X-Forwarded-For": "4.4.4.4, 5.5.5.5<script>"}, "127.0.0.1"),
            ({"X-Real-IP": "www.google.com"}, "127.0.0.1"),
        ]
        for headers, expected_ip in cases:
            self.assertEqual(self._reported("remote_ip", headers), expected_ip)

    def test_trusted_downstream(self):
        forwarded = {"X-Forwarded-For": "127.0.0.1, 4.4.4.4, 5.5.5.5"}
        resp = self.fetch("/", headers=forwarded)
        if resp.headers["request-version"].startswith("HTTP/2"):
            # This is a hack: nothing fundamentally requires HTTP/1 here, but
            # tornado_http2 doesn't support it yet.
            self.skipTest("requires HTTP/1.x")
        self.assertEqual(json_decode(resp.body)["remote_ip"], "4.4.4.4")

    def test_scheme_headers(self):
        self.assertEqual(self._reported("remote_protocol"), "http")

        cases = [
            ({"X-Scheme": "https"}, "https"),
            ({"X-Forwarded-Proto": "https"}, "https"),
            # With several comma-separated values, the last one is used.
            ({"X-Forwarded-Proto": "https , http"}, "http"),
            ({"X-Forwarded-Proto": "http,https"}, "https"),
            ({"X-Forwarded-Proto": "unknown"}, "http"),
        ]
        for headers, expected_scheme in cases:
            self.assertEqual(
                self._reported("remote_protocol", headers), expected_scheme
            )
711 712 713 714


class SSLXHeaderTest(AsyncHTTPSTestCase, HandlerBaseTestCase):
    """X-Scheme handling when the server itself is serving HTTPS."""

    def get_app(self):
        return Application([("/", XHeaderTest.Handler)])

    def get_httpserver_options(self):
        options = super().get_httpserver_options()
        options["xheaders"] = True
        return options

    def test_request_without_xprotocol(self):
        # (request headers or None, expected remote_protocol)
        expectations = [
            (None, "https"),
            ({"X-Scheme": "http"}, "http"),
            ({"X-Scheme": "unknown"}, "https"),
        ]
        for headers, protocol in expectations:
            extra = {} if headers is None else {"headers": headers}
            self.assertEqual(
                self.fetch_json("/", **extra)["remote_protocol"], protocol
            )
734

735

736 737 738 739 740 741
class ManualProtocolTest(HandlerBaseTestCase):
    """The ``protocol`` server option overrides the scheme seen by handlers."""

    class Handler(RequestHandler):
        def get(self):
            self.write({"protocol": self.request.protocol})

    def get_httpserver_options(self):
        return {"protocol": "https"}

    def test_manual_protocol(self):
        reported = self.fetch_json("/")["protocol"]
        self.assertEqual(reported, "https")
746

747

B
Ben Darnell 已提交
748 749 750 751
@unittest.skipIf(
    not hasattr(socket, "AF_UNIX") or sys.platform == "cygwin",
    "unix sockets not supported on this platform",
)
class UnixSocketTest(AsyncTestCase):
    """HTTPServer can also listen on a Unix domain socket.

    This is useful, for example, when nginx proxies to backends over unix
    sockets: a namespace of socket files can be easier to manage than a set
    of TCP port numbers. HTTP clients have no URL syntax for unix sockets,
    so these tests write requests by hand over a raw IOStream.
    """

    def setUp(self):
        super().setUp()
        self.tmpdir = tempfile.mkdtemp()
        self.sockfile = os.path.join(self.tmpdir, "test.sock")
        listener = netutil.bind_unix_socket(self.sockfile)
        self.server = HTTPServer(Application([("/hello", HelloWorldRequestHandler)]))
        self.server.add_socket(listener)
        self.stream = IOStream(socket.socket(socket.AF_UNIX))
        self.io_loop.run_sync(lambda: self.stream.connect(self.sockfile))

    def tearDown(self):
        self.stream.close()
        self.io_loop.run_sync(self.server.close_all_connections)
        self.server.stop()
        shutil.rmtree(self.tmpdir)
        super().tearDown()

    @gen_test
    def test_unix_socket(self):
        self.stream.write(b"GET /hello HTTP/1.0\r\n\r\n")
        status_line = yield self.stream.read_until(b"\r\n")
        self.assertEqual(status_line, b"HTTP/1.1 200 OK\r\n")
        raw_headers = yield self.stream.read_until(b"\r\n\r\n")
        parsed = HTTPHeaders.parse(raw_headers.decode("latin1"))
        body = yield self.stream.read_bytes(int(parsed["Content-Length"]))
        self.assertEqual(body, b"Hello world")

    @gen_test
    def test_unix_socket_bad_request(self):
        # Unix sockets don't have remote addresses, so an empty string is
        # used in their place.
        with ExpectLog(gen_log, "Malformed HTTP message from", level=logging.INFO):
            self.stream.write(b"garbage\r\n\r\n")
            reply = yield self.stream.read_until_close()
        self.assertEqual(reply, b"HTTP/1.1 400 Bad Request\r\n\r\n")
799

800 801 802 803 804 805 806

class KeepAliveTest(AsyncHTTPTestCase):
    """Tests various scenarios for HTTP 1.1 keep-alive support.

    These tests don't use AsyncHTTPClient because we want to control
    connection reuse and closing.
    """

    def get_app(self):
        class HelloHandler(RequestHandler):
            def get(self):
                self.finish("Hello world")

            def post(self):
                self.finish("Hello world")

        class LargeHandler(RequestHandler):
            def get(self):
                # 512KB should be bigger than the socket buffers so it will
                # be written out in chunks.
                self.write("".join(chr(i % 256) * 1024 for i in range(512)))

        class TransferEncodingChunkedHandler(RequestHandler):
            @gen.coroutine
            def head(self):
                self.write("Hello world")
                yield self.flush()

        class FinishOnCloseHandler(RequestHandler):
            def initialize(self, cleanup_event):
                self.cleanup_event = cleanup_event

            @gen.coroutine
            def get(self):
                self.flush()
                yield self.cleanup_event.wait()

            def on_connection_close(self):
                # This is not very realistic, but finishing the request
                # from the close callback has the right timing to mimic
                # some errors seen in the wild.
                self.finish("closed")

        self.cleanup_event = Event()
        return Application(
            [
                ("/", HelloHandler),
                ("/large", LargeHandler),
                ("/chunked", TransferEncodingChunkedHandler),
                (
                    "/finish_on_close",
                    FinishOnCloseHandler,
                    dict(cleanup_event=self.cleanup_event),
                ),
            ]
        )

    def setUp(self):
        super().setUp()
        self.http_version = b"HTTP/1.1"

    def tearDown(self):
        # We just closed the client side of the socket; let the IOLoop run
        # once to make sure the server side got the message.
        self.io_loop.add_timeout(datetime.timedelta(seconds=0.001), self.stop)
        self.wait()

        if hasattr(self, "stream"):
            self.stream.close()
        super().tearDown()

    # The next few methods are a crude manual http client
    @gen.coroutine
    def connect(self):
        """Open a raw IOStream to the test server and store it as self.stream."""
        self.stream = IOStream(socket.socket())
        yield self.stream.connect(("127.0.0.1", self.get_http_port()))

    @gen.coroutine
    def read_headers(self):
        """Read a "HTTP/1.1 200" status line and headers; resolve to the
        parsed HTTPHeaders."""
        first_line = yield self.stream.read_until(b"\r\n")
        self.assertTrue(first_line.startswith(b"HTTP/1.1 200"), first_line)
        header_bytes = yield self.stream.read_until(b"\r\n\r\n")
        headers = HTTPHeaders.parse(header_bytes.decode("latin1"))
        raise gen.Return(headers)

    @gen.coroutine
    def read_response(self):
        """Read one full "Hello world" response; headers go to self.headers."""
        self.headers = yield self.read_headers()
        body = yield self.stream.read_bytes(int(self.headers["Content-Length"]))
        self.assertEqual(b"Hello world", body)

    def close(self):
        """Close the client stream and drop it so tearDown won't close it again."""
        self.stream.close()
        del self.stream

    @gen_test
    def test_two_requests(self):
        yield self.connect()
        self.stream.write(b"GET / HTTP/1.1\r\n\r\n")
        yield self.read_response()
        self.stream.write(b"GET / HTTP/1.1\r\n\r\n")
        yield self.read_response()
        self.close()

    @gen_test
    def test_request_close(self):
        yield self.connect()
        self.stream.write(b"GET / HTTP/1.1\r\nConnection: close\r\n\r\n")
        yield self.read_response()
        data = yield self.stream.read_until_close()
        # assertEqual (rather than assertTrue(not data)) shows any
        # unexpected trailing bytes in the failure message.
        self.assertEqual(data, b"")
        self.assertEqual(self.headers["Connection"], "close")
        self.close()

    # keepalive is supported for http 1.0 too, but it's opt-in
    @gen_test
    def test_http10(self):
        self.http_version = b"HTTP/1.0"
        yield self.connect()
        self.stream.write(b"GET / HTTP/1.0\r\n\r\n")
        yield self.read_response()
        data = yield self.stream.read_until_close()
        self.assertEqual(data, b"")
        self.assertNotIn("Connection", self.headers)
        self.close()

    @gen_test
    def test_http10_keepalive(self):
        self.http_version = b"HTTP/1.0"
        yield self.connect()
        self.stream.write(b"GET / HTTP/1.0\r\nConnection: keep-alive\r\n\r\n")
        yield self.read_response()
        self.assertEqual(self.headers["Connection"], "Keep-Alive")
        self.stream.write(b"GET / HTTP/1.0\r\nConnection: keep-alive\r\n\r\n")
        yield self.read_response()
        self.assertEqual(self.headers["Connection"], "Keep-Alive")
        self.close()

    @gen_test
    def test_http10_keepalive_extra_crlf(self):
        self.http_version = b"HTTP/1.0"
        yield self.connect()
        self.stream.write(b"GET / HTTP/1.0\r\nConnection: keep-alive\r\n\r\n\r\n")
        yield self.read_response()
        self.assertEqual(self.headers["Connection"], "Keep-Alive")
        self.stream.write(b"GET / HTTP/1.0\r\nConnection: keep-alive\r\n\r\n")
        yield self.read_response()
        self.assertEqual(self.headers["Connection"], "Keep-Alive")
        self.close()

    @gen_test
    def test_pipelined_requests(self):
        yield self.connect()
        self.stream.write(b"GET / HTTP/1.1\r\n\r\nGET / HTTP/1.1\r\n\r\n")
        yield self.read_response()
        yield self.read_response()
        self.close()

    @gen_test
    def test_pipelined_cancel(self):
        yield self.connect()
        self.stream.write(b"GET / HTTP/1.1\r\n\r\nGET / HTTP/1.1\r\n\r\n")
        # only read once
        yield self.read_response()
        self.close()

    @gen_test
    def test_cancel_during_download(self):
        yield self.connect()
        self.stream.write(b"GET /large HTTP/1.1\r\n\r\n")
        yield self.read_headers()
        yield self.stream.read_bytes(1024)
        self.close()

    @gen_test
    def test_finish_while_closed(self):
        yield self.connect()
        self.stream.write(b"GET /finish_on_close HTTP/1.1\r\n\r\n")
        try:
            yield self.read_headers()
        finally:
            self.close()
            # Let the hanging coroutine clean up after itself, even when
            # read_headers failed; otherwise the handler waits forever.
            self.cleanup_event.set()

    @gen_test
    def test_keepalive_chunked(self):
        self.http_version = b"HTTP/1.0"
        yield self.connect()
        self.stream.write(
            b"POST / HTTP/1.0\r\n"
            b"Connection: keep-alive\r\n"
            b"Transfer-Encoding: chunked\r\n"
            b"\r\n"
            b"0\r\n"
            b"\r\n"
        )
        yield self.read_response()
        self.assertEqual(self.headers["Connection"], "Keep-Alive")
        self.stream.write(b"GET / HTTP/1.0\r\nConnection: keep-alive\r\n\r\n")
        yield self.read_response()
        self.assertEqual(self.headers["Connection"], "Keep-Alive")
        self.close()

    @gen_test
    def test_keepalive_chunked_head_no_body(self):
        yield self.connect()
        self.stream.write(b"HEAD /chunked HTTP/1.1\r\n\r\n")
        yield self.read_headers()

        self.stream.write(b"HEAD /chunked HTTP/1.1\r\n\r\n")
        yield self.read_headers()
        self.close()


class GzipBaseTest(AsyncHTTPTestCase):
    """Shared fixture: an echo application plus a helper that POSTs a
    gzip-compressed body."""

    def get_app(self):
        return Application([("/", EchoHandler)])

    def post_gzip(self, body):
        """POST ``body`` to ``/`` gzip-compressed with ``Content-Encoding: gzip``.

        ``body`` is passed through ``utf8`` before compression, so ``str``
        and ``bytes`` are both accepted. Returns the result of ``self.fetch``.
        """
        compressed_body = gzip.compress(utf8(body))
        return self.fetch(
            "/",
            method="POST",
            body=compressed_body,
            headers={"Content-Encoding": "gzip"},
        )

    def test_uncompressed(self):
        response = self.fetch("/", method="POST", body="foo=bar")
        self.assertEqual(json_decode(response.body), {"foo": ["bar"]})


class GzipTest(GzipBaseTest, AsyncHTTPTestCase):
    """Gzip request bodies are decoded when ``decompress_request`` is enabled."""

    def get_httpserver_options(self):
        return dict(decompress_request=True)

    def test_gzip(self):
        response = self.post_gzip("foo=bar")
        self.assertEqual(json_decode(response.body), {"foo": ["bar"]})

    def test_gzip_case_insensitive(self):
        # Content-coding values are case-insensitive:
        # https://datatracker.ietf.org/doc/html/rfc7231#section-3.1.2.1
        response = self.fetch(
            "/",
            method="POST",
            body=gzip.compress(utf8("foo=bar")),
            headers={"Content-Encoding": "GZIP"},
        )
        self.assertEqual(json_decode(response.body), {"foo": ["bar"]})


class GzipUnsupportedTest(GzipBaseTest, AsyncHTTPTestCase):
    def test_gzip_unsupported(self):
        """Without ``decompress_request`` a gzip body is not understood.

        Gzip support is opt-in, so the server cannot parse the compressed
        form body. Form-body parse failures are currently only logged, not
        fatal, so the request still succeeds and the echoed body is empty.
        """
        expected = ExpectLog(gen_log, "Unsupported Content-Encoding")
        with expected:
            resp = self.post_gzip("foo=bar")
        echoed = json_decode(resp.body)
        self.assertEqual(echoed, {})


class StreamingChunkSizeTest(AsyncHTTPTestCase):
    """Checks how the ``chunk_size`` server option splits request bodies
    into ``data_received`` calls, with and without gzip and chunked
    transfer encoding."""

    # 50 characters long, and repetitive so it can be compressed.
    BODY = b"01234567890123456789012345678901234567890123456789"
    CHUNK_SIZE = 16

    def get_http_client(self):
        # body_producer doesn't work on curl_httpclient, so override the
        # configured AsyncHTTPClient implementation.
        return SimpleAsyncHTTPClient()

    def get_httpserver_options(self):
        return dict(chunk_size=self.CHUNK_SIZE, decompress_request=True)

    class MessageDelegate(HTTPMessageDelegate):
        """Records the length of every body chunk it receives and replies
        with that list encoded as JSON."""

        def __init__(self, connection):
            self.connection = connection

        def headers_received(self, start_line, headers):
            self.chunk_lengths = []  # type: List[int]

        def data_received(self, chunk):
            self.chunk_lengths.append(len(chunk))

        def finish(self):
            payload = utf8(json_encode(self.chunk_lengths))
            status = ResponseStartLine("HTTP/1.1", 200, "OK")
            reply_headers = HTTPHeaders({"Content-Length": str(len(payload))})
            self.connection.write_headers(status, reply_headers)
            self.connection.write(payload)
            self.connection.finish()

    def get_app(self):
        class App(HTTPServerConnectionDelegate):
            def start_request(self, server_conn, request_conn):
                return StreamingChunkSizeTest.MessageDelegate(request_conn)

        return App()

    def fetch_chunk_sizes(self, **kwargs):
        """POST to the server and return the chunk lengths it reported.

        Asserts that the lengths sum to ``len(BODY)`` and that every chunk
        is non-empty and at most ``CHUNK_SIZE`` bytes.
        """
        response = self.fetch("/", method="POST", **kwargs)
        response.rethrow()
        sizes = json_decode(response.body)
        self.assertEqual(len(self.BODY), sum(sizes))
        for size in sizes:
            self.assertLessEqual(
                size, self.CHUNK_SIZE, "oversized chunk: " + str(sizes)
            )
            self.assertGreater(size, 0, "empty chunk: " + str(sizes))
        return sizes

    def compress(self, body):
        """Gzip ``body``, raising if the result is not strictly smaller."""
        buf = BytesIO()
        with gzip.GzipFile(mode="w", fileobj=buf) as gzfile:
            gzfile.write(body)
        compressed = buf.getvalue()
        if not len(compressed) < len(body):
            raise Exception("body did not shrink when compressed")
        return compressed

    def test_regular_body(self):
        # Without compression we know exactly what to expect.
        self.assertEqual(
            [16, 16, 16, 2], self.fetch_chunk_sizes(body=self.BODY)
        )

    def test_compressed_body(self):
        # Compression creates irregular boundaries so the assertions
        # in fetch_chunk_sizes are as specific as we can get.
        gzipped = self.compress(self.BODY)
        self.fetch_chunk_sizes(body=gzipped, headers={"Content-Encoding": "gzip"})

    def test_chunked_body(self):
        head, tail = self.BODY[:20], self.BODY[20:]

        def body_producer(write):
            write(head)
            write(tail)

        # HTTP chunk boundaries translate to application-visible breaks
        self.assertEqual(
            [16, 4, 16, 14], self.fetch_chunk_sizes(body_producer=body_producer)
        )

    def test_chunked_compressed(self):
        compressed = self.compress(self.BODY)
        self.assertGreater(len(compressed), 20)
        head, tail = compressed[:20], compressed[20:]

        def body_producer(write):
            write(head)
            write(tail)

        self.fetch_chunk_sizes(
            body_producer=body_producer, headers={"Content-Encoding": "gzip"}
        )


class InvalidOutputContentLengthTest(AsyncHTTPTestCase):
    """A response Content-Length that is not a plain decimal integer makes
    the request fail (logged as an uncaught exception) instead of being sent."""

    class MessageDelegate(HTTPMessageDelegate):
        def __init__(self, connection):
            self.connection = connection

        def headers_received(self, start_line, headers):
            # The request's x-test header selects which Content-Length value
            # is declared for the fixed 10-byte body below.
            declared_length = {
                "normal": "10",
                "alphabetic": "foo",
                "leading plus": "+10",
                "underscore": "1_0",
            }[headers["x-test"]]
            self.connection.write_headers(
                ResponseStartLine("HTTP/1.1", 200, "OK"),
                HTTPHeaders({"Content-Length": declared_length}),
            )
            self.connection.write(b"1234567890")
            self.connection.finish()

    def get_app(self):
        class App(HTTPServerConnectionDelegate):
            def start_request(self, server_conn, request_conn):
                return InvalidOutputContentLengthTest.MessageDelegate(request_conn)

        return App()

    def test_invalid_output_content_length(self):
        with self.subTest("normal"):
            response = self.fetch("/", method="GET", headers={"x-test": "normal"})
            response.rethrow()
            self.assertEqual(response.body, b"1234567890")
        for case in ("alphabetic", "leading plus", "underscore"):
            with self.subTest(case):
                # This log matching could be tighter, but the test already
                # covers more than is strictly necessary.
                with ExpectLog(app_log, "Uncaught exception"), self.assertRaises(
                    HTTPError
                ):
                    self.fetch("/", method="GET", headers={"x-test": case})


class MaxHeaderSizeTest(AsyncHTTPTestCase):
    """Requests whose headers exceed ``max_header_size`` are rejected."""

    def get_app(self):
        return Application([("/", HelloWorldRequestHandler)])

    def get_httpserver_options(self):
        return dict(max_header_size=1024)

    def test_small_headers(self):
        response = self.fetch("/", headers={"X-Filler": "a" * 100})
        response.rethrow()
        self.assertEqual(response.body, b"Hello world")

    def test_large_headers(self):
        with ExpectLog(gen_log, "Unsatisfiable read", required=False):
            with self.assertRaises(HTTPError) as ctx:
                self.fetch("/", headers={"X-Filler": "a" * 1000}, raise_error=True)
        # 431 is "Request Header Fields Too Large", defined in RFC
        # 6585. However, many implementations just close the
        # connection in this case, resulting in a missing response.
        response = ctx.exception.response
        if response is not None:
            self.assertIn(response.code, (431, 599))


@skipOnTravis
class IdleTimeoutTest(AsyncHTTPTestCase):
    """The server closes connections that stay idle longer than
    ``idle_connection_timeout``."""

    def get_app(self):
        return Application([("/", HelloWorldRequestHandler)])

    def get_httpserver_options(self):
        return dict(idle_connection_timeout=0.1)

    def setUp(self):
        super().setUp()
        self.streams = []  # type: List[IOStream]

    def tearDown(self):
        # Close client streams before the base-class teardown (the same
        # order KeepAliveTest.tearDown uses), and always run the base
        # teardown even if closing a stream fails.
        try:
            for stream in self.streams:
                stream.close()
        finally:
            super().tearDown()

    @gen.coroutine
    def connect(self):
        """Open a client connection to the test server; resolves to the stream.

        The stream is registered for cleanup before connecting so that it
        is still closed in tearDown if the connect attempt fails.
        """
        stream = IOStream(socket.socket())
        self.streams.append(stream)
        yield stream.connect(("127.0.0.1", self.get_http_port()))
        raise gen.Return(stream)

    @gen_test
    def test_unused_connection(self):
        stream = yield self.connect()
        event = Event()
        stream.set_close_callback(event.set)
        yield event.wait()

    @gen_test
    def test_idle_after_use(self):
        stream = yield self.connect()
        event = Event()
        stream.set_close_callback(event.set)

        # Use the connection twice to make sure keep-alives are working
        for i in range(2):
            stream.write(b"GET / HTTP/1.1\r\n\r\n")
            yield stream.read_until(b"\r\n\r\n")
            data = yield stream.read_bytes(11)
            self.assertEqual(data, b"Hello world")

        # Now let the timeout trigger and close the connection.
        yield event.wait()


class BodyLimitsTest(AsyncHTTPTestCase):
    """Exercises ``max_body_size`` and ``body_timeout`` for buffered and
    streaming handlers, including per-request overrides."""

    def get_app(self):
        class BufferedHandler(RequestHandler):
            def put(self):
                self.write(str(len(self.request.body)))

        @stream_request_body
        class StreamingHandler(RequestHandler):
            def initialize(self):
                self.bytes_read = 0

            def prepare(self):
                # Query arguments let a test override this connection's
                # limits for the current request.
                conn = typing.cast(HTTP1Connection, self.request.connection)
                args = self.request.arguments
                if "expected_size" in args:
                    conn.set_max_body_size(int(self.get_argument("expected_size")))
                if "body_timeout" in args:
                    conn.set_body_timeout(float(self.get_argument("body_timeout")))

            def data_received(self, data):
                self.bytes_read += len(data)

            def put(self):
                self.write(str(self.bytes_read))

        routes = [("/buffered", BufferedHandler), ("/streaming", StreamingHandler)]
        return Application(routes)

    def get_httpserver_options(self):
        return dict(body_timeout=3600, max_body_size=4096)

    def get_http_client(self):
        # body_producer doesn't work on curl_httpclient, so override the
        # configured AsyncHTTPClient implementation.
        return SimpleAsyncHTTPClient()

    def test_small_body(self):
        for path in ("/buffered", "/streaming"):
            resp = self.fetch(path, method="PUT", body=b"a" * 4096)
            self.assertEqual(resp.body, b"4096")

    def test_large_body_buffered(self):
        with ExpectLog(gen_log, ".*Content-Length too long", level=logging.INFO):
            resp = self.fetch("/buffered", method="PUT", body=b"a" * 10240)
        self.assertEqual(resp.code, 400)

    # This test is flaky on windows for unknown reasons.
    @unittest.skipIf(os.name == "nt", "flaky on windows")
    def test_large_body_buffered_chunked(self):
        with ExpectLog(gen_log, ".*chunked body too large", level=logging.INFO):
            resp = self.fetch(
                "/buffered",
                method="PUT",
                body_producer=lambda write: write(b"a" * 10240),
            )
        self.assertEqual(resp.code, 400)

    def test_large_body_streaming(self):
        with ExpectLog(gen_log, ".*Content-Length too long", level=logging.INFO):
            resp = self.fetch("/streaming", method="PUT", body=b"a" * 10240)
        self.assertEqual(resp.code, 400)

    @unittest.skipIf(os.name == "nt", "flaky on windows")
    def test_large_body_streaming_chunked(self):
        with ExpectLog(gen_log, ".*chunked body too large", level=logging.INFO):
            resp = self.fetch(
                "/streaming",
                method="PUT",
                body_producer=lambda write: write(b"a" * 10240),
            )
        self.assertEqual(resp.code, 400)

    def test_large_body_streaming_override(self):
        resp = self.fetch(
            "/streaming?expected_size=10240", method="PUT", body=b"a" * 10240
        )
        self.assertEqual(resp.body, b"10240")

    def test_large_body_streaming_chunked_override(self):
        resp = self.fetch(
            "/streaming?expected_size=10240",
            method="PUT",
            body_producer=lambda write: write(b"a" * 10240),
        )
        self.assertEqual(resp.body, b"10240")

    @gen_test
    def test_timeout(self):
        with closing(IOStream(socket.socket())) as stream:
            yield stream.connect(("127.0.0.1", self.get_http_port()))
            # Use a raw stream because AsyncHTTPClient won't let us read a
            # response without finishing a body.
            stream.write(
                b"PUT /streaming?body_timeout=0.1 HTTP/1.0\r\n"
                b"Content-Length: 42\r\n\r\n"
            )
            with ExpectLog(gen_log, "Timeout reading body", level=logging.INFO):
                reply = yield stream.read_until_close()
            self.assertEqual(reply, b"")

    @gen_test
    def test_body_size_override_reset(self):
        """The max_body_size override is reset between requests."""
        with closing(IOStream(socket.socket())) as stream:
            yield stream.connect(("127.0.0.1", self.get_http_port()))
            # Use a raw stream so we can make sure it's all on one connection.
            stream.write(
                b"PUT /streaming?expected_size=10240 HTTP/1.1\r\n"
                b"Content-Length: 10240\r\n\r\n"
            )
            stream.write(b"a" * 10240)
            _start_line, _headers, body = yield read_stream_body(stream)
            self.assertEqual(body, b"10240")
            # Without the ?expected_size parameter, we get the old default value
            stream.write(b"PUT /streaming HTTP/1.1\r\nContent-Length: 10240\r\n\r\n")
            with ExpectLog(gen_log, ".*Content-Length too long", level=logging.INFO):
                leftover = yield stream.read_until_close()
            self.assertEqual(leftover, b"HTTP/1.1 400 Bad Request\r\n\r\n")


class LegacyInterfaceTest(AsyncHTTPTestCase):
    """The legacy single-callable ``request_callback`` interface still works."""

    def get_app(self):
        # The old request_callback interface does not implement the
        # delegate interface. This callback writes a raw HTTP/1.x response
        # with request.connection.write instead of calling
        # request.connection.write_headers.
        def handle_request(request):
            self.http1 = request.version.startswith("HTTP/1.")
            conn = request.connection
            if self.http1:
                body = b"Hello world"
                head = "HTTP/1.1 200 OK\r\n" "Content-Length: %d\r\n\r\n" % len(body)
                conn.write(utf8(head))
                conn.write(body)
            else:
                # This test will be skipped if we're using HTTP/2,
                # so just close it out cleanly using the modern interface.
                conn.write_headers(ResponseStartLine("", 200, "OK"), HTTPHeaders())
            conn.finish()

        return handle_request

    def test_legacy_interface(self):
        resp = self.fetch("/")
        if not self.http1:
            self.skipTest("requires HTTP/1.x")
        self.assertEqual(resp.body, b"Hello world")