From 0a8cccc24e64a6029e391e3f80b114c75766337f Mon Sep 17 00:00:00 2001 From: James Myatt Date: Thu, 29 Sep 2016 22:01:03 +0100 Subject: [PATCH] Remove support for SSL without SSLContext (Fixes #115) Signed-off-by: James Myatt --- src/paho/mqtt/client.py | 260 +++++++----------- test/lib/context.py | 4 +- test/lib/python/08-ssl-bad-cacert.test | 2 +- test/lib/python/08-ssl-connect-cert-auth.test | 2 +- test/lib/python/08-ssl-connect-no-auth.test | 2 +- test/lib/python/08-ssl-fake-cacert.test | 2 +- 6 files changed, 105 insertions(+), 167 deletions(-) diff --git a/src/paho/mqtt/client.py b/src/paho/mqtt/client.py index da3f56d..41cd522 100755 --- a/src/paho/mqtt/client.py +++ b/src/paho/mqtt/client.py @@ -24,10 +24,7 @@ import socket try: import ssl except ImportError: - HAVE_SSL = False ssl = None -else: - HAVE_SSL = True import struct import sys @@ -258,16 +255,6 @@ def _socketpair_compat(): return (sock1, sock2) -def _check_can_read_file(filename): - if filename: - try: - f = open(filename, "r") - except IOError as err: - raise IOError(filename + ": " + err.strerror) - else: - f.close() - - class MQTTMessageInfo: """This is a class returned from Client.publish() and can be used to find out the mid of the message that was published, and to determine whether the @@ -446,6 +433,7 @@ class Client(object): MQTT_LOG_ERR, and MQTT_LOG_DEBUG. The message itself is in buf. """ + def __init__(self, client_id="", clean_session=True, userdata=None, protocol=MQTTv311, transport="tcp"): """client_id is the unique client id string used when connecting to the broker. If client_id is zero length or None, then the behaviour is @@ -546,7 +534,7 @@ class Client(object): self._thread_terminate = False self._ssl = False self._ssl_context = None - self._tls_insecure = False + self._tls_insecure = False # Only used when SSL context does not have check_hostname attribute self._logger = None # No default callbacks self._on_log = None @@ -576,19 +564,15 @@ class Client(object): def tls_set_context(self, context=None): """Configure network encryption and authentication context. Enables SSL/TLS support. - context : an ssl.SSLContext object, or a dictionary containing - arguments for ssl.wrap_socket. By default this is given by + context : an ssl.SSLContext object. By default this is given by `ssl.create_default_context()`, if available. Must be called before connect() or connect_async().""" if self._ssl_context is not None: raise ValueError('SSL/TLS has already been configured.') - if HAVE_SSL is False: - raise ValueError('This platform has no SSL/TLS.') - - if sys.version_info < (2, 7): - raise ValueError('Python 2.7 is the minimum supported version for TLS.') + # Assume that have SSL support, or at least that context input behaves like ssl.SSLContext + # in current versions of Python if context is None: if hasattr(ssl, 'create_default_context'): @@ -599,6 +583,10 @@ class Client(object): self._ssl = True self._ssl_context = context + # Ensure _tls_insecure is consistent with check_hostname attribute + if hasattr(context, 'check_hostname'): + self._tls_insecure = not context.check_hostname + def tls_set(self, ca_certs, certfile=None, keyfile=None, cert_reqs=None, tls_version=None, ciphers=None): """Configure network encryption and authentication options. Enables SSL/TLS support. @@ -634,53 +622,37 @@ class Client(object): more information. Must be called before connect() or connect_async().""" - if HAVE_SSL is False: + if ssl is None: raise ValueError('This platform has no SSL/TLS.') - if sys.version_info < (2, 7): - raise ValueError('Python 2.7 is the minimum supported version for TLS.') + if not hasattr(ssl, 'SSLContext'): + # Require Python version that has SSL context support in standard library + raise ValueError('Python 2.7.9 and 3.2 are the minimum supported versions for TLS.') if ca_certs is None: raise ValueError('ca_certs must not be None.') - # Load defaults - if cert_reqs is None: - cert_reqs = ssl.CERT_REQUIRED + # Create SSLContext object if tls_version is None: tls_version = ssl.PROTOCOL_TLSv1 + context = ssl.SSLContext(tls_version) - if hasattr(ssl, 'SSLContext'): - # Create SSLContext object - context = ssl.SSLContext(tls_version) - - # Configure context - if certfile is not None: - context.load_cert_chain(certfile, keyfile) - if cert_reqs is not None: - context.verify_mode = cert_reqs - if ca_certs is not None: - context.load_verify_locations(ca_certs) - if ciphers is not None: - context.set_ciphers(ciphers) - else: - # Revert to version without SSLContext, since not available - - _check_can_read_file(ca_certs) - _check_can_read_file(certfile) - _check_can_read_file(keyfile) - - # Dictionary of arguments for ssl.wrap_socket - context = { - 'certfile': certfile, - 'keyfile': keyfile, - 'ca_certs': ca_certs, - 'cert_reqs': cert_reqs, - 'ciphers': ciphers, - 'ssl_version': tls_version - } + # Configure context + if certfile is not None: + context.load_cert_chain(certfile, keyfile) + + context.verify_mode = ssl.CERT_REQUIRED if cert_reqs is None else cert_reqs + + context.load_verify_locations(ca_certs) + + if ciphers is not None: + context.set_ciphers(ciphers) self.tls_set_context(context) + # Default to secure, sets context.check_hostname attribute if available + self.tls_insecure_set(False) + def tls_insecure_set(self, value): """Configure verification of the server hostname in the server certificate. @@ -693,12 +665,20 @@ class Client(object): Do not use this function in a real system. Setting value to true means there is no point using encryption. - Must be called before connect().""" - if HAVE_SSL is False: - raise ValueError('This platform has no SSL/TLS.') + Must be called before connect() and after either tls_set() or + tls_set_context().""" + + if self._ssl_context is None: + raise ValueError('Must configure SSL context before using tls_insecure_set.') self._tls_insecure = value + # Ensure check_hostname is consistent with _tls_insecure attribute + if hasattr(self._ssl_context, 'check_hostname'): + # Rely on SSLContext to check host name + # If verify_mode is CERT_NONE then the host name will never be checked + self._ssl_context.check_hostname = not value + def enable_logger(self, logger=None): if not logger: if self._logger: @@ -847,21 +827,23 @@ class Client(object): raise if self._ssl: - if isinstance(self._ssl_context, dict): - # Version without SSL Context - sock = ssl.wrap_socket( - sock, **self._ssl_context) + # SSL is only supported when SSLContext is available (implies Python >= 2.7.9 or >= 3.2) + + verify_host = not self._tls_insecure + try: + # Try with server_hostname, even it's not supported in certain scenarios + sock = self._ssl_context.wrap_socket(sock, server_hostname=self._host) + except ValueError: + # Python version requires SNI in order to handle server_hostname, but SNI is not available + sock = self._ssl_context.wrap_socket(sock) else: - # Use SSLContext (implies Python >= 3.2) - server_hostname = self._host if ssl.HAS_SNI else None - sock = self._ssl_context.wrap_socket( - sock, server_hostname=server_hostname) - - if not self._tls_insecure: - if sys.version_info < (2, 7, 9) or (sys.version_info[0] == 3 and sys.version_info[1] < 2): - self._tls_match_hostname(sock) - else: - ssl.match_hostname(sock.getpeercert(), self._host) + # If SSL context has already checked hostname, then don't need to do it again + if (hasattr(self._ssl_context, 'check_hostname') and + self._ssl_context.check_hostname): + verify_host = False + + if verify_host: + ssl.match_hostname(sock.getpeercert(), self._host) if self._transport == "websockets": sock = WebsocketWrapper(sock, self._host, self._port, self._ssl) @@ -1256,7 +1238,7 @@ class Client(object): now = time_func() self._check_keepalive() - if self._last_retry_check+1 < now: + if self._last_retry_check + 1 < now: # Only check once a second at most self._message_retry_check() self._last_retry_check = now @@ -1409,9 +1391,9 @@ class Client(object): # so no other threads can access _current_out_packet, # _out_packet or _messages. if (self._thread_terminate is True - and self._current_out_packet is None - and len(self._out_packet) == 0 - and len(self._out_messages) == 0): + and self._current_out_packet is None + and len(self._out_packet) == 0 + and len(self._out_messages) == 0): rc = 1 run = False @@ -1740,7 +1722,7 @@ class Client(object): byte, = struct.unpack("!B", byte) self._in_packet['remaining_count'].append(byte) # Max 4 bytes length for remaining length as defined by protocol. - # Anything more likely means a broken/malicious client. + # Anything more likely means a broken/malicious client. if len(self._in_packet['remaining_count']) > 4: return MQTT_ERR_PROTOCOL @@ -1907,8 +1889,8 @@ class Client(object): @staticmethod def _topic_wildcard_len_check(topic): # Search for + or # in a topic. Return MQTT_ERR_INVAL if found. - # Also returns MQTT_ERR_INVAL if the topic string is too long. - # Returns MQTT_ERR_SUCCESS if everything is fine. + # Also returns MQTT_ERR_INVAL if the topic string is too long. + # Returns MQTT_ERR_SUCCESS if everything is fine. if b'+' in topic or b'#' in topic or len(topic) == 0 or len(topic) > 65535: return MQTT_ERR_INVAL else: @@ -1970,12 +1952,12 @@ class Client(object): if self._sock is None: return MQTT_ERR_NO_CONN - command = PUBLISH | ((dup&0x1)<<3) | (qos<<1) | retain + command = PUBLISH | ((dup & 0x1) << 3) | (qos << 1) | retain packet = bytearray() packet.append(command) payloadlen = len(payload) - remaining_length = 2+len(topic) + payloadlen + remaining_length = 2 + len(topic) + payloadlen if payloadlen == 0: self._easy_log( @@ -2011,7 +1993,7 @@ class Client(object): def _send_pubrel(self, mid, dup=False): self._easy_log(MQTT_LOG_DEBUG, "Sending PUBREL (Mid: %d)", mid) - return self._send_command_with_mid(PUBREL|2, mid, dup) + return self._send_command_with_mid(PUBREL | 2, mid, dup) def _send_command_with_mid(self, command, mid, dup): # For PUBACK, PUBCOMP, PUBREC, and PUBREL @@ -2037,21 +2019,21 @@ class Client(object): proto_ver = 4 protocol = protocol.encode('utf-8') - remaining_length = 2+len(protocol) + 1+1+2 + 2+len(self._client_id) + remaining_length = 2 + len(protocol) + 1 + 1 + 2 + 2 + len(self._client_id) connect_flags = 0 if clean_session: connect_flags |= 0x02 if self._will: - remaining_length += 2+len(self._will_topic) + 2+len(self._will_payload) - connect_flags |= 0x04 | ((self._will_qos&0x03) << 3) | ((self._will_retain&0x01) << 5) + remaining_length += 2 + len(self._will_topic) + 2 + len(self._will_payload) + connect_flags |= 0x04 | ((self._will_qos & 0x03) << 3) | ((self._will_retain & 0x01) << 5) if self._username is not None: - remaining_length += 2+len(self._username) + remaining_length += 2 + len(self._username) connect_flags |= 0x80 if self._password is not None: connect_flags |= 0x40 - remaining_length += 2+len(self._password) + remaining_length += 2 + len(self._password) command = CONNECT packet = bytearray() @@ -2094,9 +2076,9 @@ class Client(object): def _send_subscribe(self, dup, topics): remaining_length = 2 for t, _ in topics: - remaining_length += 2+len(t)+1 + remaining_length += 2 + len(t) + 1 - command = SUBSCRIBE | (dup<<3) | 0x2 + command = SUBSCRIBE | (dup << 3) | 0x2 packet = bytearray() packet.append(command) self._pack_remaining_length(packet, remaining_length) @@ -2110,9 +2092,9 @@ class Client(object): def _send_unsubscribe(self, dup, topics): remaining_length = 2 for t in topics: - remaining_length += 2+len(t) + remaining_length += 2 + len(t) - command = UNSUBSCRIBE | (dup<<3) | 0x2 + command = UNSUBSCRIBE | (dup << 3) | 0x2 packet = bytearray() packet.append(command) self._pack_remaining_length(packet, remaining_length) @@ -2121,7 +2103,7 @@ class Client(object): for t in topics: self._pack_str16(packet, t) - #topics_repr = ", ".join("'"+topic.decode('utf8')+"'" for topic in topics) + # topics_repr = ", ".join("'"+topic.decode('utf8')+"'" for topic in topics) self._easy_log(MQTT_LOG_DEBUG, "Sending UNSUBSCRIBE (d%d) %s", dup, topics) return (self._packet_queue(command, packet, local_mid, 1), local_mid) @@ -2157,12 +2139,12 @@ class Client(object): if m.qos == 0: m.state = mqtt_ms_publish elif m.qos == 1: - #self._inflight_messages = self._inflight_messages + 1 + # self._inflight_messages = self._inflight_messages + 1 if m.state == mqtt_ms_wait_for_puback: m.dup = True m.state = mqtt_ms_publish elif m.qos == 2: - #self._inflight_messages = self._inflight_messages + 1 + # self._inflight_messages = self._inflight_messages + 1 if m.state == mqtt_ms_wait_for_pubcomp: m.state = mqtt_ms_resend_pubrel m.dup = True @@ -2311,12 +2293,12 @@ class Client(object): for m in self._out_messages: m.timestamp = time_func() if m.state == mqtt_ms_queued: - self.loop_write() # Process outgoing messages that have just been queued up + self.loop_write() # Process outgoing messages that have just been queued up self._out_message_mutex.release() return MQTT_ERR_SUCCESS if m.qos == 0: - self._in_callback = True # Don't call loop_write after _send_publish() + self._in_callback = True # Don't call loop_write after _send_publish() rc = self._send_publish(m.mid, m.topic, m.payload, m.qos, m.retain, m.dup) self._in_callback = False if rc != 0: @@ -2326,7 +2308,7 @@ class Client(object): if m.state == mqtt_ms_publish: self._inflight_messages += 1 m.state = mqtt_ms_wait_for_puback - self._in_callback = True # Don't call loop_write after _send_publish() + self._in_callback = True # Don't call loop_write after _send_publish() rc = self._send_publish(m.mid, m.topic, m.payload, m.qos, m.retain, m.dup) self._in_callback = False if rc != 0: @@ -2336,7 +2318,7 @@ class Client(object): if m.state == mqtt_ms_publish: self._inflight_messages += 1 m.state = mqtt_ms_wait_for_pubrec - self._in_callback = True # Don't call loop_write after _send_publish() + self._in_callback = True # Don't call loop_write after _send_publish() rc = self._send_publish(m.mid, m.topic, m.payload, m.qos, m.retain, m.dup) self._in_callback = False if rc != 0: @@ -2345,13 +2327,13 @@ class Client(object): elif m.state == mqtt_ms_resend_pubrel: self._inflight_messages += 1 m.state = mqtt_ms_wait_for_pubcomp - self._in_callback = True # Don't call loop_write after _send_pubrel() + self._in_callback = True # Don't call loop_write after _send_pubrel() rc = self._send_pubrel(m.mid, m.dup) self._in_callback = False if rc != 0: self._out_message_mutex.release() return rc - self.loop_write() # Process outgoing messages that have just been queued up + self.loop_write() # Process outgoing messages that have just been queued up self._out_message_mutex.release() return rc elif result > 0 and result < 6: @@ -2361,9 +2343,9 @@ class Client(object): def _handle_suback(self): self._easy_log(MQTT_LOG_DEBUG, "Received SUBACK") - pack_format = "!H" + str(len(self._in_packet['packet'])-2) + 's' + pack_format = "!H" + str(len(self._in_packet['packet']) - 2) + 's' (mid, packet) = struct.unpack(pack_format, self._in_packet['packet']) - pack_format = "!" + "B"*len(packet) + pack_format = "!" + "B" * len(packet) granted_qos = struct.unpack(pack_format, packet) self._callback_mutex.acquire() @@ -2380,13 +2362,13 @@ class Client(object): header = self._in_packet['command'] message = MQTTMessage() - message.dup = (header & 0x08)>>3 - message.qos = (header & 0x06)>>1 + message.dup = (header & 0x08) >> 3 + message.qos = (header & 0x06) >> 1 message.retain = (header & 0x01) - pack_format = "!H" + str(len(self._in_packet['packet'])-2) + 's' + pack_format = "!H" + str(len(self._in_packet['packet']) - 2) + 's' (slen, packet) = struct.unpack(pack_format, self._in_packet['packet']) - pack_format = '!' + str(slen) + 's' + str(len(packet)-slen) + 's' + pack_format = '!' + str(slen) + 's' + str(len(packet) - slen) + 's' (message.topic, packet) = struct.unpack(pack_format, packet) if len(message.topic) == 0: @@ -2396,7 +2378,7 @@ class Client(object): message.topic = message.topic.decode('utf-8') if message.qos > 0: - pack_format = "!H" + str(len(packet)-2) + 's' + pack_format = "!H" + str(len(packet) - 2) + 's' (message.mid, packet) = struct.unpack(pack_format, packet) message.payload = packet @@ -2573,58 +2555,14 @@ class Client(object): def _thread_main(self): self.loop_forever(retry_first_connection=True) - def _host_matches_cert(self, host, cert_host): - if cert_host[0:2] == "*.": - if cert_host.count("*") != 1: - return False - - host_match = host.split(".", 1)[1] - cert_match = cert_host.split(".", 1)[1] - return host_match == cert_match - else: - return host == cert_host - - def _tls_match_hostname(self, sock): - try: - cert = sock.getpeercert() - except AttributeError: - # the getpeercert can throw Attribute error: object has no attribute 'peer_certificate' - # Don't let that crash the whole client. See also: http://bugs.python.org/issue13721 - raise ssl.SSLError('Not connected') - - san = cert.get('subjectAltName') - if san: - have_san_dns = False - for (key, value) in san: - if key == 'DNS': - have_san_dns = True - if self._host_matches_cert(self._host.lower(), value.lower()): - return - if key == 'IP Address': - have_san_dns = True - if value.lower() == self._host.lower(): - return - - if have_san_dns: - # Only check subject if subjectAltName dns not found. - raise ssl.SSLError('Certificate subject does not match remote hostname.') - subject = cert.get('subject') - if subject: - for ((key, value),) in subject: - if key == 'commonName': - if self._host_matches_cert(self._host.lower(), value.lower()): - return - - raise ssl.SSLError('Certificate subject does not match remote hostname.') - # Compatibility class for easy porting from mosquitto.py. class Mosquitto(Client): def __init__(self, client_id="", clean_session=True, userdata=None): super(Mosquitto, self).__init__(client_id, clean_session, userdata) -class WebsocketWrapper: +class WebsocketWrapper: OPCODE_CONTINUATION = 0x0 OPCODE_TEXT = 0x1 OPCODE_BINARY = 0x2 @@ -2660,13 +2598,13 @@ class WebsocketWrapper: sec_websocket_key = uuid.uuid4().bytes sec_websocket_key = base64.b64encode(sec_websocket_key) - header = b"GET /mqtt HTTP/1.1\r\n" +\ - b"Upgrade: websocket\r\n" +\ - b"Connection: Upgrade\r\n" +\ - b"Host: " + str(self._host).encode('utf-8') + b":" + str(self._port).encode('utf-8') + b"\r\n" +\ + header = b"GET /mqtt HTTP/1.1\r\n" + \ + b"Upgrade: websocket\r\n" + \ + b"Connection: Upgrade\r\n" + \ + b"Host: " + str(self._host).encode('utf-8') + b":" + str(self._port).encode('utf-8') + b"\r\n" + \ b"Origin: http://" + str(self._host).encode('utf-8') + b":" + str(self._port).encode('utf-8') + b"\r\n" +\ - b"Sec-WebSocket-Key: " + sec_websocket_key + b"\r\n" +\ - b"Sec-WebSocket-Version: 13\r\n" +\ + b"Sec-WebSocket-Key: " + sec_websocket_key + b"\r\n" + \ + b"Sec-WebSocket-Version: 13\r\n" + \ b"Sec-WebSocket-Protocol: mqtt\r\n\r\n" self._socket.send(header) @@ -2765,7 +2703,7 @@ class WebsocketWrapper: self._readbuffer.extend(data) self._readbuffer_head += length - return self._readbuffer[self._readbuffer_head-length:self._readbuffer_head] + return self._readbuffer[self._readbuffer_head - length:self._readbuffer_head] def _recv_impl(self, length): @@ -2890,5 +2828,5 @@ class WebsocketWrapper: def fileno(self): return self._socket.fileno() - def setblocking(self,flag): + def setblocking(self, flag): self._socket.setblocking(flag) diff --git a/test/lib/context.py b/test/lib/context.py index c90d73b..29750b9 100644 --- a/test/lib/context.py +++ b/test/lib/context.py @@ -52,6 +52,6 @@ def check_ssl(): print("WARNING: SSL not available in current environment") exit(0) - if sys.version < '2.7': - print("WARNING: SSL not supported on Python 2.6") + if not hasattr(ssl, 'SSLContext'): + print("WARNING: SSL without SSLContext is not supported") exit(0) diff --git a/test/lib/python/08-ssl-bad-cacert.test b/test/lib/python/08-ssl-bad-cacert.test index 42da4fd..5e36c18 100755 --- a/test/lib/python/08-ssl-bad-cacert.test +++ b/test/lib/python/08-ssl-bad-cacert.test @@ -9,7 +9,7 @@ from struct import * import paho.mqtt.client as mqtt -if sys.version < '2.7': +if sys.version_info < (2, 7, 9): print("WARNING: SSL/TLS not supported on Python 2.6") exit(0) diff --git a/test/lib/python/08-ssl-connect-cert-auth.test b/test/lib/python/08-ssl-connect-cert-auth.test index 4450f56..f5e8cea 100755 --- a/test/lib/python/08-ssl-connect-cert-auth.test +++ b/test/lib/python/08-ssl-connect-cert-auth.test @@ -9,7 +9,7 @@ from struct import * import paho.mqtt.client as mqtt -if sys.version_info < (2, 7): +if sys.version_info < (2, 7, 9): print("WARNING: SSL/TLS not supported on Python 2.6") exit(0) diff --git a/test/lib/python/08-ssl-connect-no-auth.test b/test/lib/python/08-ssl-connect-no-auth.test index 6810294..5398b04 100755 --- a/test/lib/python/08-ssl-connect-no-auth.test +++ b/test/lib/python/08-ssl-connect-no-auth.test @@ -9,7 +9,7 @@ from struct import * import paho.mqtt.client as mqtt -if sys.version_info < (2, 7): +if sys.version_info < (2, 7, 9): print("WARNING: SSL/TLS not supported on Python 2.6") exit(0) diff --git a/test/lib/python/08-ssl-fake-cacert.test b/test/lib/python/08-ssl-fake-cacert.test index be7c21c..d671225 100755 --- a/test/lib/python/08-ssl-fake-cacert.test +++ b/test/lib/python/08-ssl-fake-cacert.test @@ -10,7 +10,7 @@ import ssl import paho.mqtt.client as mqtt -if sys.version_info < (2, 7): +if sys.version_info < (2, 7, 9): print("WARNING: SSL/TLS not supported on Python 2.6") exit(0) -- GitLab