提交 0a8cccc2 编写于 作者: J James Myatt

Remove support for SSL without SSLContext (Fixes #115)

Signed-off-by: NJames Myatt <james@jamesmyatt.co.uk>
上级 37cf9537
......@@ -24,10 +24,7 @@ import socket
try:
import ssl
except ImportError:
HAVE_SSL = False
ssl = None
else:
HAVE_SSL = True
import struct
import sys
......@@ -258,16 +255,6 @@ def _socketpair_compat():
return (sock1, sock2)
def _check_can_read_file(filename):
if filename:
try:
f = open(filename, "r")
except IOError as err:
raise IOError(filename + ": " + err.strerror)
else:
f.close()
class MQTTMessageInfo:
"""This is a class returned from Client.publish() and can be used to find
out the mid of the message that was published, and to determine whether the
......@@ -446,6 +433,7 @@ class Client(object):
MQTT_LOG_ERR, and MQTT_LOG_DEBUG. The message itself is in buf.
"""
def __init__(self, client_id="", clean_session=True, userdata=None, protocol=MQTTv311, transport="tcp"):
"""client_id is the unique client id string used when connecting to the
broker. If client_id is zero length or None, then the behaviour is
......@@ -546,7 +534,7 @@ class Client(object):
self._thread_terminate = False
self._ssl = False
self._ssl_context = None
self._tls_insecure = False
self._tls_insecure = False # Only used when SSL context does not have check_hostname attribute
self._logger = None
# No default callbacks
self._on_log = None
......@@ -576,19 +564,15 @@ class Client(object):
def tls_set_context(self, context=None):
"""Configure network encryption and authentication context. Enables SSL/TLS support.
context : an ssl.SSLContext object, or a dictionary containing
arguments for ssl.wrap_socket. By default this is given by
context : an ssl.SSLContext object. By default this is given by
`ssl.create_default_context()`, if available.
Must be called before connect() or connect_async()."""
if self._ssl_context is not None:
raise ValueError('SSL/TLS has already been configured.')
if HAVE_SSL is False:
raise ValueError('This platform has no SSL/TLS.')
if sys.version_info < (2, 7):
raise ValueError('Python 2.7 is the minimum supported version for TLS.')
# Assume that have SSL support, or at least that context input behaves like ssl.SSLContext
# in current versions of Python
if context is None:
if hasattr(ssl, 'create_default_context'):
......@@ -599,6 +583,10 @@ class Client(object):
self._ssl = True
self._ssl_context = context
# Ensure _tls_insecure is consistent with check_hostname attribute
if hasattr(context, 'check_hostname'):
self._tls_insecure = not context.check_hostname
def tls_set(self, ca_certs, certfile=None, keyfile=None, cert_reqs=None, tls_version=None, ciphers=None):
"""Configure network encryption and authentication options. Enables SSL/TLS support.
......@@ -634,53 +622,37 @@ class Client(object):
more information.
Must be called before connect() or connect_async()."""
if HAVE_SSL is False:
if ssl is None:
raise ValueError('This platform has no SSL/TLS.')
if sys.version_info < (2, 7):
raise ValueError('Python 2.7 is the minimum supported version for TLS.')
if not hasattr(ssl, 'SSLContext'):
# Require Python version that has SSL context support in standard library
raise ValueError('Python 2.7.9 and 3.2 are the minimum supported versions for TLS.')
if ca_certs is None:
raise ValueError('ca_certs must not be None.')
# Load defaults
if cert_reqs is None:
cert_reqs = ssl.CERT_REQUIRED
# Create SSLContext object
if tls_version is None:
tls_version = ssl.PROTOCOL_TLSv1
context = ssl.SSLContext(tls_version)
if hasattr(ssl, 'SSLContext'):
# Create SSLContext object
context = ssl.SSLContext(tls_version)
# Configure context
if certfile is not None:
context.load_cert_chain(certfile, keyfile)
if cert_reqs is not None:
context.verify_mode = cert_reqs
if ca_certs is not None:
context.load_verify_locations(ca_certs)
if ciphers is not None:
context.set_ciphers(ciphers)
else:
# Revert to version without SSLContext, since not available
_check_can_read_file(ca_certs)
_check_can_read_file(certfile)
_check_can_read_file(keyfile)
# Dictionary of arguments for ssl.wrap_socket
context = {
'certfile': certfile,
'keyfile': keyfile,
'ca_certs': ca_certs,
'cert_reqs': cert_reqs,
'ciphers': ciphers,
'ssl_version': tls_version
}
# Configure context
if certfile is not None:
context.load_cert_chain(certfile, keyfile)
context.verify_mode = ssl.CERT_REQUIRED if cert_reqs is None else cert_reqs
context.load_verify_locations(ca_certs)
if ciphers is not None:
context.set_ciphers(ciphers)
self.tls_set_context(context)
# Default to secure, sets context.check_hostname attribute if available
self.tls_insecure_set(False)
def tls_insecure_set(self, value):
"""Configure verification of the server hostname in the server certificate.
......@@ -693,12 +665,20 @@ class Client(object):
Do not use this function in a real system. Setting value to true means
there is no point using encryption.
Must be called before connect()."""
if HAVE_SSL is False:
raise ValueError('This platform has no SSL/TLS.')
Must be called before connect() and after either tls_set() or
tls_set_context()."""
if self._ssl_context is None:
raise ValueError('Must configure SSL context before using tls_insecure_set.')
self._tls_insecure = value
# Ensure check_hostname is consistent with _tls_insecure attribute
if hasattr(self._ssl_context, 'check_hostname'):
# Rely on SSLContext to check host name
# If verify_mode is CERT_NONE then the host name will never be checked
self._ssl_context.check_hostname = not value
def enable_logger(self, logger=None):
if not logger:
if self._logger:
......@@ -847,21 +827,23 @@ class Client(object):
raise
if self._ssl:
if isinstance(self._ssl_context, dict):
# Version without SSL Context
sock = ssl.wrap_socket(
sock, **self._ssl_context)
# SSL is only supported when SSLContext is available (implies Python >= 2.7.9 or >= 3.2)
verify_host = not self._tls_insecure
try:
# Try with server_hostname, even it's not supported in certain scenarios
sock = self._ssl_context.wrap_socket(sock, server_hostname=self._host)
except ValueError:
# Python version requires SNI in order to handle server_hostname, but SNI is not available
sock = self._ssl_context.wrap_socket(sock)
else:
# Use SSLContext (implies Python >= 3.2)
server_hostname = self._host if ssl.HAS_SNI else None
sock = self._ssl_context.wrap_socket(
sock, server_hostname=server_hostname)
if not self._tls_insecure:
if sys.version_info < (2, 7, 9) or (sys.version_info[0] == 3 and sys.version_info[1] < 2):
self._tls_match_hostname(sock)
else:
ssl.match_hostname(sock.getpeercert(), self._host)
# If SSL context has already checked hostname, then don't need to do it again
if (hasattr(self._ssl_context, 'check_hostname') and
self._ssl_context.check_hostname):
verify_host = False
if verify_host:
ssl.match_hostname(sock.getpeercert(), self._host)
if self._transport == "websockets":
sock = WebsocketWrapper(sock, self._host, self._port, self._ssl)
......@@ -1256,7 +1238,7 @@ class Client(object):
now = time_func()
self._check_keepalive()
if self._last_retry_check+1 < now:
if self._last_retry_check + 1 < now:
# Only check once a second at most
self._message_retry_check()
self._last_retry_check = now
......@@ -1409,9 +1391,9 @@ class Client(object):
# so no other threads can access _current_out_packet,
# _out_packet or _messages.
if (self._thread_terminate is True
and self._current_out_packet is None
and len(self._out_packet) == 0
and len(self._out_messages) == 0):
and self._current_out_packet is None
and len(self._out_packet) == 0
and len(self._out_messages) == 0):
rc = 1
run = False
......@@ -1740,7 +1722,7 @@ class Client(object):
byte, = struct.unpack("!B", byte)
self._in_packet['remaining_count'].append(byte)
# Max 4 bytes length for remaining length as defined by protocol.
# Anything more likely means a broken/malicious client.
# Anything more likely means a broken/malicious client.
if len(self._in_packet['remaining_count']) > 4:
return MQTT_ERR_PROTOCOL
......@@ -1907,8 +1889,8 @@ class Client(object):
@staticmethod
def _topic_wildcard_len_check(topic):
# Search for + or # in a topic. Return MQTT_ERR_INVAL if found.
# Also returns MQTT_ERR_INVAL if the topic string is too long.
# Returns MQTT_ERR_SUCCESS if everything is fine.
# Also returns MQTT_ERR_INVAL if the topic string is too long.
# Returns MQTT_ERR_SUCCESS if everything is fine.
if b'+' in topic or b'#' in topic or len(topic) == 0 or len(topic) > 65535:
return MQTT_ERR_INVAL
else:
......@@ -1970,12 +1952,12 @@ class Client(object):
if self._sock is None:
return MQTT_ERR_NO_CONN
command = PUBLISH | ((dup&0x1)<<3) | (qos<<1) | retain
command = PUBLISH | ((dup & 0x1) << 3) | (qos << 1) | retain
packet = bytearray()
packet.append(command)
payloadlen = len(payload)
remaining_length = 2+len(topic) + payloadlen
remaining_length = 2 + len(topic) + payloadlen
if payloadlen == 0:
self._easy_log(
......@@ -2011,7 +1993,7 @@ class Client(object):
def _send_pubrel(self, mid, dup=False):
self._easy_log(MQTT_LOG_DEBUG, "Sending PUBREL (Mid: %d)", mid)
return self._send_command_with_mid(PUBREL|2, mid, dup)
return self._send_command_with_mid(PUBREL | 2, mid, dup)
def _send_command_with_mid(self, command, mid, dup):
# For PUBACK, PUBCOMP, PUBREC, and PUBREL
......@@ -2037,21 +2019,21 @@ class Client(object):
proto_ver = 4
protocol = protocol.encode('utf-8')
remaining_length = 2+len(protocol) + 1+1+2 + 2+len(self._client_id)
remaining_length = 2 + len(protocol) + 1 + 1 + 2 + 2 + len(self._client_id)
connect_flags = 0
if clean_session:
connect_flags |= 0x02
if self._will:
remaining_length += 2+len(self._will_topic) + 2+len(self._will_payload)
connect_flags |= 0x04 | ((self._will_qos&0x03) << 3) | ((self._will_retain&0x01) << 5)
remaining_length += 2 + len(self._will_topic) + 2 + len(self._will_payload)
connect_flags |= 0x04 | ((self._will_qos & 0x03) << 3) | ((self._will_retain & 0x01) << 5)
if self._username is not None:
remaining_length += 2+len(self._username)
remaining_length += 2 + len(self._username)
connect_flags |= 0x80
if self._password is not None:
connect_flags |= 0x40
remaining_length += 2+len(self._password)
remaining_length += 2 + len(self._password)
command = CONNECT
packet = bytearray()
......@@ -2094,9 +2076,9 @@ class Client(object):
def _send_subscribe(self, dup, topics):
remaining_length = 2
for t, _ in topics:
remaining_length += 2+len(t)+1
remaining_length += 2 + len(t) + 1
command = SUBSCRIBE | (dup<<3) | 0x2
command = SUBSCRIBE | (dup << 3) | 0x2
packet = bytearray()
packet.append(command)
self._pack_remaining_length(packet, remaining_length)
......@@ -2110,9 +2092,9 @@ class Client(object):
def _send_unsubscribe(self, dup, topics):
remaining_length = 2
for t in topics:
remaining_length += 2+len(t)
remaining_length += 2 + len(t)
command = UNSUBSCRIBE | (dup<<3) | 0x2
command = UNSUBSCRIBE | (dup << 3) | 0x2
packet = bytearray()
packet.append(command)
self._pack_remaining_length(packet, remaining_length)
......@@ -2121,7 +2103,7 @@ class Client(object):
for t in topics:
self._pack_str16(packet, t)
#topics_repr = ", ".join("'"+topic.decode('utf8')+"'" for topic in topics)
# topics_repr = ", ".join("'"+topic.decode('utf8')+"'" for topic in topics)
self._easy_log(MQTT_LOG_DEBUG, "Sending UNSUBSCRIBE (d%d) %s", dup, topics)
return (self._packet_queue(command, packet, local_mid, 1), local_mid)
......@@ -2157,12 +2139,12 @@ class Client(object):
if m.qos == 0:
m.state = mqtt_ms_publish
elif m.qos == 1:
#self._inflight_messages = self._inflight_messages + 1
# self._inflight_messages = self._inflight_messages + 1
if m.state == mqtt_ms_wait_for_puback:
m.dup = True
m.state = mqtt_ms_publish
elif m.qos == 2:
#self._inflight_messages = self._inflight_messages + 1
# self._inflight_messages = self._inflight_messages + 1
if m.state == mqtt_ms_wait_for_pubcomp:
m.state = mqtt_ms_resend_pubrel
m.dup = True
......@@ -2311,12 +2293,12 @@ class Client(object):
for m in self._out_messages:
m.timestamp = time_func()
if m.state == mqtt_ms_queued:
self.loop_write() # Process outgoing messages that have just been queued up
self.loop_write() # Process outgoing messages that have just been queued up
self._out_message_mutex.release()
return MQTT_ERR_SUCCESS
if m.qos == 0:
self._in_callback = True # Don't call loop_write after _send_publish()
self._in_callback = True # Don't call loop_write after _send_publish()
rc = self._send_publish(m.mid, m.topic, m.payload, m.qos, m.retain, m.dup)
self._in_callback = False
if rc != 0:
......@@ -2326,7 +2308,7 @@ class Client(object):
if m.state == mqtt_ms_publish:
self._inflight_messages += 1
m.state = mqtt_ms_wait_for_puback
self._in_callback = True # Don't call loop_write after _send_publish()
self._in_callback = True # Don't call loop_write after _send_publish()
rc = self._send_publish(m.mid, m.topic, m.payload, m.qos, m.retain, m.dup)
self._in_callback = False
if rc != 0:
......@@ -2336,7 +2318,7 @@ class Client(object):
if m.state == mqtt_ms_publish:
self._inflight_messages += 1
m.state = mqtt_ms_wait_for_pubrec
self._in_callback = True # Don't call loop_write after _send_publish()
self._in_callback = True # Don't call loop_write after _send_publish()
rc = self._send_publish(m.mid, m.topic, m.payload, m.qos, m.retain, m.dup)
self._in_callback = False
if rc != 0:
......@@ -2345,13 +2327,13 @@ class Client(object):
elif m.state == mqtt_ms_resend_pubrel:
self._inflight_messages += 1
m.state = mqtt_ms_wait_for_pubcomp
self._in_callback = True # Don't call loop_write after _send_pubrel()
self._in_callback = True # Don't call loop_write after _send_pubrel()
rc = self._send_pubrel(m.mid, m.dup)
self._in_callback = False
if rc != 0:
self._out_message_mutex.release()
return rc
self.loop_write() # Process outgoing messages that have just been queued up
self.loop_write() # Process outgoing messages that have just been queued up
self._out_message_mutex.release()
return rc
elif result > 0 and result < 6:
......@@ -2361,9 +2343,9 @@ class Client(object):
def _handle_suback(self):
self._easy_log(MQTT_LOG_DEBUG, "Received SUBACK")
pack_format = "!H" + str(len(self._in_packet['packet'])-2) + 's'
pack_format = "!H" + str(len(self._in_packet['packet']) - 2) + 's'
(mid, packet) = struct.unpack(pack_format, self._in_packet['packet'])
pack_format = "!" + "B"*len(packet)
pack_format = "!" + "B" * len(packet)
granted_qos = struct.unpack(pack_format, packet)
self._callback_mutex.acquire()
......@@ -2380,13 +2362,13 @@ class Client(object):
header = self._in_packet['command']
message = MQTTMessage()
message.dup = (header & 0x08)>>3
message.qos = (header & 0x06)>>1
message.dup = (header & 0x08) >> 3
message.qos = (header & 0x06) >> 1
message.retain = (header & 0x01)
pack_format = "!H" + str(len(self._in_packet['packet'])-2) + 's'
pack_format = "!H" + str(len(self._in_packet['packet']) - 2) + 's'
(slen, packet) = struct.unpack(pack_format, self._in_packet['packet'])
pack_format = '!' + str(slen) + 's' + str(len(packet)-slen) + 's'
pack_format = '!' + str(slen) + 's' + str(len(packet) - slen) + 's'
(message.topic, packet) = struct.unpack(pack_format, packet)
if len(message.topic) == 0:
......@@ -2396,7 +2378,7 @@ class Client(object):
message.topic = message.topic.decode('utf-8')
if message.qos > 0:
pack_format = "!H" + str(len(packet)-2) + 's'
pack_format = "!H" + str(len(packet) - 2) + 's'
(message.mid, packet) = struct.unpack(pack_format, packet)
message.payload = packet
......@@ -2573,58 +2555,14 @@ class Client(object):
def _thread_main(self):
self.loop_forever(retry_first_connection=True)
def _host_matches_cert(self, host, cert_host):
if cert_host[0:2] == "*.":
if cert_host.count("*") != 1:
return False
host_match = host.split(".", 1)[1]
cert_match = cert_host.split(".", 1)[1]
return host_match == cert_match
else:
return host == cert_host
def _tls_match_hostname(self, sock):
try:
cert = sock.getpeercert()
except AttributeError:
# the getpeercert can throw Attribute error: object has no attribute 'peer_certificate'
# Don't let that crash the whole client. See also: http://bugs.python.org/issue13721
raise ssl.SSLError('Not connected')
san = cert.get('subjectAltName')
if san:
have_san_dns = False
for (key, value) in san:
if key == 'DNS':
have_san_dns = True
if self._host_matches_cert(self._host.lower(), value.lower()):
return
if key == 'IP Address':
have_san_dns = True
if value.lower() == self._host.lower():
return
if have_san_dns:
# Only check subject if subjectAltName dns not found.
raise ssl.SSLError('Certificate subject does not match remote hostname.')
subject = cert.get('subject')
if subject:
for ((key, value),) in subject:
if key == 'commonName':
if self._host_matches_cert(self._host.lower(), value.lower()):
return
raise ssl.SSLError('Certificate subject does not match remote hostname.')
# Compatibility class for easy porting from mosquitto.py.
class Mosquitto(Client):
def __init__(self, client_id="", clean_session=True, userdata=None):
super(Mosquitto, self).__init__(client_id, clean_session, userdata)
class WebsocketWrapper:
class WebsocketWrapper:
OPCODE_CONTINUATION = 0x0
OPCODE_TEXT = 0x1
OPCODE_BINARY = 0x2
......@@ -2660,13 +2598,13 @@ class WebsocketWrapper:
sec_websocket_key = uuid.uuid4().bytes
sec_websocket_key = base64.b64encode(sec_websocket_key)
header = b"GET /mqtt HTTP/1.1\r\n" +\
b"Upgrade: websocket\r\n" +\
b"Connection: Upgrade\r\n" +\
b"Host: " + str(self._host).encode('utf-8') + b":" + str(self._port).encode('utf-8') + b"\r\n" +\
header = b"GET /mqtt HTTP/1.1\r\n" + \
b"Upgrade: websocket\r\n" + \
b"Connection: Upgrade\r\n" + \
b"Host: " + str(self._host).encode('utf-8') + b":" + str(self._port).encode('utf-8') + b"\r\n" + \
b"Origin: http://" + str(self._host).encode('utf-8') + b":" + str(self._port).encode('utf-8') + b"\r\n" +\
b"Sec-WebSocket-Key: " + sec_websocket_key + b"\r\n" +\
b"Sec-WebSocket-Version: 13\r\n" +\
b"Sec-WebSocket-Key: " + sec_websocket_key + b"\r\n" + \
b"Sec-WebSocket-Version: 13\r\n" + \
b"Sec-WebSocket-Protocol: mqtt\r\n\r\n"
self._socket.send(header)
......@@ -2765,7 +2703,7 @@ class WebsocketWrapper:
self._readbuffer.extend(data)
self._readbuffer_head += length
return self._readbuffer[self._readbuffer_head-length:self._readbuffer_head]
return self._readbuffer[self._readbuffer_head - length:self._readbuffer_head]
def _recv_impl(self, length):
......@@ -2890,5 +2828,5 @@ class WebsocketWrapper:
def fileno(self):
return self._socket.fileno()
def setblocking(self,flag):
def setblocking(self, flag):
self._socket.setblocking(flag)
......@@ -52,6 +52,6 @@ def check_ssl():
print("WARNING: SSL not available in current environment")
exit(0)
if sys.version < '2.7':
print("WARNING: SSL not supported on Python 2.6")
if not hasattr(ssl, 'SSLContext'):
print("WARNING: SSL without SSLContext is not supported")
exit(0)
......@@ -9,7 +9,7 @@ from struct import *
import paho.mqtt.client as mqtt
if sys.version < '2.7':
if sys.version_info < (2, 7, 9):
print("WARNING: SSL/TLS not supported on Python 2.6")
exit(0)
......
......@@ -9,7 +9,7 @@ from struct import *
import paho.mqtt.client as mqtt
if sys.version_info < (2, 7):
if sys.version_info < (2, 7, 9):
print("WARNING: SSL/TLS not supported on Python 2.6")
exit(0)
......
......@@ -9,7 +9,7 @@ from struct import *
import paho.mqtt.client as mqtt
if sys.version_info < (2, 7):
if sys.version_info < (2, 7, 9):
print("WARNING: SSL/TLS not supported on Python 2.6")
exit(0)
......
......@@ -10,7 +10,7 @@ import ssl
import paho.mqtt.client as mqtt
if sys.version_info < (2, 7):
if sys.version_info < (2, 7, 9):
print("WARNING: SSL/TLS not supported on Python 2.6")
exit(0)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册