diff --git a/tests/test_websocket_integration.py b/tests/test_websocket_integration.py index bb47513924b1bbd3c838190f68105d107d9c10ca..1fe0001f0de487984b372779ac599943c46d63cb 100644 --- a/tests/test_websocket_integration.py +++ b/tests/test_websocket_integration.py @@ -53,7 +53,12 @@ class TestInvalidWebsocketResponse(object): transport="websockets" ) - with fake_websocket_broker.serve("200 OK\n".encode("utf8")): + class WebsocketHandler(socketserver.BaseRequestHandler): + def handle(_self): + # Respond with data passed in to serve() + _self.request.sendall("200 OK".encode("utf8")) + + with fake_websocket_broker.serve(WebsocketHandler): with pytest.raises(WebsocketConnectionError) as exc: mqttc.connect("localhost", 1888, keepalive=10) @@ -197,13 +202,14 @@ class TestValidHeaders(object): path=mqtt_path, ) - def create_response_hash(decoded): + def check_path_correct(decoded): # Make sure it connects to the right path - assert re.search("GET {:s} HTTP/1.1".format(mqtt_path), decoded, re.IGNORECASE) is not None + if mqtt_path: + assert re.search("GET {:s} HTTP/1.1".format(mqtt_path), decoded, re.IGNORECASE) is not None response = self._get_callback_handler( init_response_headers, - check_request=create_response_hash, + check_request=check_path_correct, ) with fake_websocket_broker.serve(response): @@ -231,14 +237,15 @@ class TestValidHeaders(object): headers=auth_headers, ) - def create_response_hash(decoded): + def check_headers_used(decoded): # Make sure it connects to the right path - for h in auth_headers: - assert re.search("{:s}: {:s}".format(h, auth_headers[h]), decoded, re.IGNORECASE) is not None + if auth_headers: + for h in auth_headers: + assert re.search("{:s}: {:s}".format(h, auth_headers[h]), decoded, re.IGNORECASE) is not None response = self._get_callback_handler( init_response_headers, - check_request=create_response_hash, + check_request=check_headers_used, ) with fake_websocket_broker.serve(response):