diff --git a/setup.py b/setup.py index 9ff9bd8b882c79b345c1457c9d42d878824a7221..8af2a26486241992b2411ffa42f91a04ad9d1731 100644 --- a/setup.py +++ b/setup.py @@ -14,6 +14,9 @@ requirements = [] test_requirements = ['pytest'] setup_requirements = ['pytest-runner'] +if sys.version_info < (3, 0): + test_requirements += ['mock'] + setup( name='paho-mqtt', version=__version__, diff --git a/tests/test_websocket_integration.py b/tests/test_websocket_integration.py new file mode 100644 index 0000000000000000000000000000000000000000..63f888cbe1d8108d4433dd1099f844868bde6ede --- /dev/null +++ b/tests/test_websocket_integration.py @@ -0,0 +1,249 @@ +import base64 +import re +import hashlib +from collections import OrderedDict + +from six.moves import socketserver +import pytest +import paho.mqtt.client as client + +from paho.mqtt.client import WebsocketConnectionError +from testsupport.broker import fake_websocket_broker + + +pytestmark = [ + pytest.mark.skipif( + not pytest.config.getoption("--run-integration-tests"), + reason="Specify --run-integration-tests to run these tests", + ), +] + + +@pytest.fixture +def init_response_headers(): + # "Normal" websocket response from server + response_headers = OrderedDict([ + ("Upgrade", "websocket"), + ("Connection", "Upgrade"), + ("Sec-WebSocket-Accept", "testwebsocketkey"), + ("Sec-WebSocket-Protocol", "chat"), + ]) + + return response_headers + + +def get_websocket_response(response_headers): + """ Takes headers and constructs HTTP response + + 'HTTP/1.1 101 Switching Protocols' is the headers for the response, + as expected in client.py + """ + response = "\r\n".join([ + "HTTP/1.1 101 Switching Protocols", + "\r\n".join("{}: {}".format(i, j) for i, j in response_headers.items()), + "\r\n", + ]).encode("utf8") + + return response + + +@pytest.mark.parametrize("proto_ver,proto_name", [ + (client.MQTTv31, "MQIsdp"), + (client.MQTTv311, "MQTT"), +]) +class TestInvalidWebsocketResponse(object): + def test_unexpected_response(self, proto_ver, proto_name, fake_websocket_broker): + """ Server responds with a valid code, but it's not what the client expected """ + + mqttc = client.Client( + "test_unexpected_response", + protocol=proto_ver, + transport="websockets" + ) + + with fake_websocket_broker.serve("200 OK\n"): + with pytest.raises(WebsocketConnectionError) as exc: + mqttc.connect("localhost", 1888, keepalive=10) + + assert str(exc.value) == "WebSocket handshake error" + + +@pytest.mark.parametrize("proto_ver,proto_name", [ + (client.MQTTv31, "MQIsdp"), + (client.MQTTv311, "MQTT"), +]) +class TestBadWebsocketHeaders(object): + """ Testing for basic functionality in checking for headers """ + + def _get_basic_handler(self, response_headers): + """ Get a basic BaseRequestHandler which returns the information in + self._response_headers + """ + + response = get_websocket_response(response_headers) + + class WebsocketHandler(socketserver.BaseRequestHandler): + def handle(_self): + self.data = _self.request.recv(1024).strip() + print("Received '{:s}'".format(self.data.decode("utf8"))) + # Respond with data passed in to serve() + _self.request.sendall(response) + + return WebsocketHandler + + def test_no_upgrade(self, proto_ver, proto_name, fake_websocket_broker, + init_response_headers): + """ Server doesn't respond with 'connection: upgrade' """ + + mqttc = client.Client( + "test_no_upgrade", + protocol=proto_ver, + transport="websockets" + ) + + init_response_headers["Connection"] = "bad" + response = self._get_basic_handler(init_response_headers) + + with fake_websocket_broker.serve(response): + with pytest.raises(WebsocketConnectionError) as exc: + mqttc.connect("localhost", 1888, keepalive=10) + + assert str(exc.value) == "WebSocket handshake error, connection not upgraded" + + def test_bad_secret_key(self, proto_ver, proto_name, fake_websocket_broker, + init_response_headers): + """ Server doesn't give anything after connection: upgrade """ + + mqttc = client.Client( + "test_bad_secret_key", + protocol=proto_ver, + transport="websockets" + ) + + response = self._get_basic_handler(init_response_headers) + + with fake_websocket_broker.serve(response): + with pytest.raises(WebsocketConnectionError) as exc: + mqttc.connect("localhost", 1888, keepalive=10) + + assert str(exc.value) == "WebSocket handshake error, invalid secret key" + + +@pytest.mark.parametrize("proto_ver,proto_name", [ + (client.MQTTv31, "MQIsdp"), + (client.MQTTv311, "MQTT"), +]) +class TestValidHeaders(object): + """ Testing for functionality in request/response headers """ + + def _get_callback_handler(self, response_headers, check_request=None): + """ Get a basic BaseRequestHandler which returns the information in + self._response_headers + """ + + class WebsocketHandler(socketserver.BaseRequestHandler): + def handle(_self): + self.data = _self.request.recv(1024).strip() + print("Received '{:s}'".format(self.data.decode("utf8"))) + + decoded = self.data.decode("utf8") + + if check_request is not None: + check_request(decoded) + + # Create server hash + GUID = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11" + key = re.search("sec-websocket-key: ([A-Za-z0-9+/=]*)", decoded, re.IGNORECASE).group(1) + + to_hash = "{:s}{:s}".format(key, GUID) + hashed = hashlib.sha1(to_hash.encode("utf8")) + encoded = base64.b64encode(hashed.digest()).decode("utf8") + + response_headers["Sec-WebSocket-Accept"] = encoded + + # Respond with the correct hash + response = get_websocket_response(response_headers) + + _self.request.sendall(response) + + return WebsocketHandler + + def test_successful_connection(self, proto_ver, proto_name, + fake_websocket_broker, + init_response_headers): + """ Connect successfully, on correct path """ + + mqttc = client.Client( + "test_successful_connection", + protocol=proto_ver, + transport="websockets" + ) + + response = self._get_callback_handler(init_response_headers) + + with fake_websocket_broker.serve(response): + mqttc.connect("localhost", 1888, keepalive=10) + + mqttc.disconnect() + + @pytest.mark.parametrize("mqtt_path", [ + "/mqtt" + "/special", + None, + ]) + def test_correct_path(self, proto_ver, proto_name, fake_websocket_broker, + mqtt_path, init_response_headers): + """ Make sure it can connect on user specified paths """ + + mqttc = client.Client( + "test_correct_path", + protocol=proto_ver, + transport="websockets" + ) + + mqttc.ws_set_options( + path=mqtt_path, + ) + + def create_response_hash(decoded): + # Make sure it connects to the right path + assert re.search("GET {:s} HTTP/1.1".format(mqtt_path), decoded, re.IGNORECASE) is not None + + response = self._get_callback_handler(init_response_headers) + + with fake_websocket_broker.serve(response): + mqttc.connect("localhost", 1888, keepalive=10) + + mqttc.disconnect() + + @pytest.mark.parametrize("auth_headers", [ + {"Authorization": "test123"}, + {"Authorization": "test123", "auth2": "abcdef"}, + # Won't be checked, but make sure it still works even if the user passes it + None, + ]) + def test_correct_auth(self, proto_ver, proto_name, fake_websocket_broker, + auth_headers, init_response_headers): + """ Make sure it sends the right auth headers """ + + mqttc = client.Client( + "test_correct_path", + protocol=proto_ver, + transport="websockets" + ) + + mqttc.ws_set_options( + headers=auth_headers, + ) + + def create_response_hash(decoded): + # Make sure it connects to the right path + for h in auth_headers: + assert re.search("{:s}: {:s}".format(h, auth_headers[h]), decoded, re.IGNORECASE) is not None + + response = self._get_callback_handler(init_response_headers) + + with fake_websocket_broker.serve(response): + mqttc.connect("localhost", 1888, keepalive=10) + + mqttc.disconnect() diff --git a/tests/test_websockets.py b/tests/test_websockets.py index 63f888cbe1d8108d4433dd1099f844868bde6ede..0e7111160f49e3c0a25564f8f65b54882b68783b 100644 --- a/tests/test_websockets.py +++ b/tests/test_websockets.py @@ -1,249 +1,70 @@ -import base64 -import re -import hashlib -from collections import OrderedDict +import socket +import sys +import contextlib -from six.moves import socketserver -import pytest -import paho.mqtt.client as client - -from paho.mqtt.client import WebsocketConnectionError -from testsupport.broker import fake_websocket_broker - - -pytestmark = [ - pytest.mark.skipif( - not pytest.config.getoption("--run-integration-tests"), - reason="Specify --run-integration-tests to run these tests", - ), -] - - -@pytest.fixture -def init_response_headers(): - # "Normal" websocket response from server - response_headers = OrderedDict([ - ("Upgrade", "websocket"), - ("Connection", "Upgrade"), - ("Sec-WebSocket-Accept", "testwebsocketkey"), - ("Sec-WebSocket-Protocol", "chat"), - ]) - - return response_headers - - -def get_websocket_response(response_headers): - """ Takes headers and constructs HTTP response - - 'HTTP/1.1 101 Switching Protocols' is the headers for the response, - as expected in client.py - """ - response = "\r\n".join([ - "HTTP/1.1 101 Switching Protocols", - "\r\n".join("{}: {}".format(i, j) for i, j in response_headers.items()), - "\r\n", - ]).encode("utf8") - - return response - - -@pytest.mark.parametrize("proto_ver,proto_name", [ - (client.MQTTv31, "MQIsdp"), - (client.MQTTv311, "MQTT"), -]) -class TestInvalidWebsocketResponse(object): - def test_unexpected_response(self, proto_ver, proto_name, fake_websocket_broker): - """ Server responds with a valid code, but it's not what the client expected """ - - mqttc = client.Client( - "test_unexpected_response", - protocol=proto_ver, - transport="websockets" - ) - - with fake_websocket_broker.serve("200 OK\n"): - with pytest.raises(WebsocketConnectionError) as exc: - mqttc.connect("localhost", 1888, keepalive=10) - - assert str(exc.value) == "WebSocket handshake error" - - -@pytest.mark.parametrize("proto_ver,proto_name", [ - (client.MQTTv31, "MQIsdp"), - (client.MQTTv311, "MQTT"), -]) -class TestBadWebsocketHeaders(object): - """ Testing for basic functionality in checking for headers """ - - def _get_basic_handler(self, response_headers): - """ Get a basic BaseRequestHandler which returns the information in - self._response_headers - """ - - response = get_websocket_response(response_headers) - - class WebsocketHandler(socketserver.BaseRequestHandler): - def handle(_self): - self.data = _self.request.recv(1024).strip() - print("Received '{:s}'".format(self.data.decode("utf8"))) - # Respond with data passed in to serve() - _self.request.sendall(response) - - return WebsocketHandler - - def test_no_upgrade(self, proto_ver, proto_name, fake_websocket_broker, - init_response_headers): - """ Server doesn't respond with 'connection: upgrade' """ - - mqttc = client.Client( - "test_no_upgrade", - protocol=proto_ver, - transport="websockets" - ) - - init_response_headers["Connection"] = "bad" - response = self._get_basic_handler(init_response_headers) - - with fake_websocket_broker.serve(response): - with pytest.raises(WebsocketConnectionError) as exc: - mqttc.connect("localhost", 1888, keepalive=10) - - assert str(exc.value) == "WebSocket handshake error, connection not upgraded" - - def test_bad_secret_key(self, proto_ver, proto_name, fake_websocket_broker, - init_response_headers): - """ Server doesn't give anything after connection: upgrade """ +if sys.version_info < (3, 0): + from mock import patch, Mock +else: + from unittest.mock import patch, Mock - mqttc = client.Client( - "test_bad_secret_key", - protocol=proto_ver, - transport="websockets" - ) - - response = self._get_basic_handler(init_response_headers) - - with fake_websocket_broker.serve(response): - with pytest.raises(WebsocketConnectionError) as exc: - mqttc.connect("localhost", 1888, keepalive=10) - - assert str(exc.value) == "WebSocket handshake error, invalid secret key" - - -@pytest.mark.parametrize("proto_ver,proto_name", [ - (client.MQTTv31, "MQIsdp"), - (client.MQTTv311, "MQTT"), -]) -class TestValidHeaders(object): - """ Testing for functionality in request/response headers """ - - def _get_callback_handler(self, response_headers, check_request=None): - """ Get a basic BaseRequestHandler which returns the information in - self._response_headers - """ - - class WebsocketHandler(socketserver.BaseRequestHandler): - def handle(_self): - self.data = _self.request.recv(1024).strip() - print("Received '{:s}'".format(self.data.decode("utf8"))) - - decoded = self.data.decode("utf8") - - if check_request is not None: - check_request(decoded) - - # Create server hash - GUID = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11" - key = re.search("sec-websocket-key: ([A-Za-z0-9+/=]*)", decoded, re.IGNORECASE).group(1) - - to_hash = "{:s}{:s}".format(key, GUID) - hashed = hashlib.sha1(to_hash.encode("utf8")) - encoded = base64.b64encode(hashed.digest()).decode("utf8") - - response_headers["Sec-WebSocket-Accept"] = encoded - - # Respond with the correct hash - response = get_websocket_response(response_headers) - - _self.request.sendall(response) - - return WebsocketHandler - - def test_successful_connection(self, proto_ver, proto_name, - fake_websocket_broker, - init_response_headers): - """ Connect successfully, on correct path """ - - mqttc = client.Client( - "test_successful_connection", - protocol=proto_ver, - transport="websockets" - ) - - response = self._get_callback_handler(init_response_headers) - - with fake_websocket_broker.serve(response): - mqttc.connect("localhost", 1888, keepalive=10) - - mqttc.disconnect() - - @pytest.mark.parametrize("mqtt_path", [ - "/mqtt" - "/special", - None, - ]) - def test_correct_path(self, proto_ver, proto_name, fake_websocket_broker, - mqtt_path, init_response_headers): - """ Make sure it can connect on user specified paths """ - - mqttc = client.Client( - "test_correct_path", - protocol=proto_ver, - transport="websockets" - ) +import pytest +from paho.mqtt.client import WebsocketWrapper, WebsocketConnectionError - mqttc.ws_set_options( - path=mqtt_path, - ) - def create_response_hash(decoded): - # Make sure it connects to the right path - assert re.search("GET {:s} HTTP/1.1".format(mqtt_path), decoded, re.IGNORECASE) is not None +class TestHeaders(object): + """ Make sure headers are used correctly """ - response = self._get_callback_handler(init_response_headers) + def test_normal_headers(self): + """ Normal headers as specified in RFC 6455 """ - with fake_websocket_broker.serve(response): - mqttc.connect("localhost", 1888, keepalive=10) + response = [ + "HTTP/1.1 101 Switching Protocols", + "Upgrade: websocket", + "Connection: Upgrade", + "Sec-WebSocket-Accept: badreturnvalue=", + "Sec-WebSocket-Protocol: chat", + "\r\n", + ] - mqttc.disconnect() + def iter_response(): + for i in "\r\n".join(response): + yield i - @pytest.mark.parametrize("auth_headers", [ - {"Authorization": "test123"}, - {"Authorization": "test123", "auth2": "abcdef"}, - # Won't be checked, but make sure it still works even if the user passes it - None, - ]) - def test_correct_auth(self, proto_ver, proto_name, fake_websocket_broker, - auth_headers, init_response_headers): - """ Make sure it sends the right auth headers """ + it = iter_response() - mqttc = client.Client( - "test_correct_path", - protocol=proto_ver, - transport="websockets" - ) + def fakerecv(*args): + return next(it) - mqttc.ws_set_options( - headers=auth_headers, + mocksock = Mock( + spec_set=socket.socket, + recv=fakerecv, + send=Mock(), ) + host = "testhost.com" + port = 1234 + path = "/mqtt" + extra_headers = None + is_ssl = True - def create_response_hash(decoded): - # Make sure it connects to the right path - for h in auth_headers: - assert re.search("{:s}: {:s}".format(h, auth_headers[h]), decoded, re.IGNORECASE) is not None - - response = self._get_callback_handler(init_response_headers) + with pytest.raises(WebsocketConnectionError) as exc: + w = WebsocketWrapper(mocksock, host, port, is_ssl, path, extra_headers) - with fake_websocket_broker.serve(response): - mqttc.connect("localhost", 1888, keepalive=10) + # We're not creating the response hash properly so it should raise this + # error + assert str(exc.value) == "WebSocket handshake error, invalid secret key" - mqttc.disconnect() + expected_sent = [i.format(**locals()) for i in [ + "GET {path:s} HTTP/1.1", + "Host: {host:s}", + "Upgrade: websocket", + "Connection: Upgrade", + "Sec-Websocket-Protocol: mqtt", + "Sec-Websocket-Version: 13", + "Origin: https://{host:s}:{port:d}", + ]] + + # Only sends the header once + assert mocksock.send.call_count == 1 + + for i in expected_sent: + assert i in mocksock.send.call_args[0][0]