From 7818fd3363d9425a1111b3f6baa885fb1f8855b3 Mon Sep 17 00:00:00 2001 From: Krzysiek S Date: Mon, 6 Feb 2023 20:11:22 +0100 Subject: [PATCH] Allow passing IP as connect method parameter in WiFiClientSecure and skip unnecessary host-ip conversions (#7643) --- .../WiFiClientSecure/src/WiFiClientSecure.cpp | 20 ++++++++++++++++--- .../WiFiClientSecure/src/WiFiClientSecure.h | 1 + libraries/WiFiClientSecure/src/ssl_client.cpp | 11 +++------- libraries/WiFiClientSecure/src/ssl_client.h | 2 +- 4 files changed, 22 insertions(+), 12 deletions(-) diff --git a/libraries/WiFiClientSecure/src/WiFiClientSecure.cpp b/libraries/WiFiClientSecure/src/WiFiClientSecure.cpp index a6814b0a6..b46fd15e3 100644 --- a/libraries/WiFiClientSecure/src/WiFiClientSecure.cpp +++ b/libraries/WiFiClientSecure/src/WiFiClientSecure.cpp @@ -124,12 +124,21 @@ int WiFiClientSecure::connect(const char *host, uint16_t port, int32_t timeout){ int WiFiClientSecure::connect(IPAddress ip, uint16_t port, const char *CA_cert, const char *cert, const char *private_key) { - return connect(ip.toString().c_str(), port, CA_cert, cert, private_key); + return connect(ip, port, NULL, CA_cert, cert, private_key); } int WiFiClientSecure::connect(const char *host, uint16_t port, const char *CA_cert, const char *cert, const char *private_key) { - int ret = start_ssl_client(sslclient, host, port, _timeout, CA_cert, _use_ca_bundle, cert, private_key, NULL, NULL, _use_insecure, _alpn_protos); + IPAddress address; + if (!WiFi.hostByName(host, address)) + return 0; + + return connect(address, port, host, CA_cert, cert, private_key); +} + +int WiFiClientSecure::connect(IPAddress ip, uint16_t port, const char *host, const char *CA_cert, const char *cert, const char *private_key) +{ + int ret = start_ssl_client(sslclient, ip, port, host, _timeout, CA_cert, _use_ca_bundle, cert, private_key, NULL, NULL, _use_insecure, _alpn_protos); _lastError = ret; if (ret < 0) { log_e("start_ssl_client: %d", ret); @@ -146,7 +155,12 @@ int WiFiClientSecure::connect(IPAddress ip, uint16_t port, const char *pskIdent, int WiFiClientSecure::connect(const char *host, uint16_t port, const char *pskIdent, const char *psKey) { log_v("start_ssl_client with PSK"); - int ret = start_ssl_client(sslclient, host, port, _timeout, NULL, false, NULL, NULL, pskIdent, psKey, _use_insecure, _alpn_protos); + + IPAddress address; + if (!WiFi.hostByName(host, address)) + return 0; + + int ret = start_ssl_client(sslclient, address, port, host, _timeout, NULL, false, NULL, NULL, pskIdent, psKey, _use_insecure, _alpn_protos); _lastError = ret; if (ret < 0) { log_e("start_ssl_client: %d", ret); diff --git a/libraries/WiFiClientSecure/src/WiFiClientSecure.h b/libraries/WiFiClientSecure/src/WiFiClientSecure.h index 2d4961042..6c967fbd0 100644 --- a/libraries/WiFiClientSecure/src/WiFiClientSecure.h +++ b/libraries/WiFiClientSecure/src/WiFiClientSecure.h @@ -55,6 +55,7 @@ public: int connect(const char *host, uint16_t port, const char *rootCABuff, const char *cli_cert, const char *cli_key); int connect(IPAddress ip, uint16_t port, const char *pskIdent, const char *psKey); int connect(const char *host, uint16_t port, const char *pskIdent, const char *psKey); + int connect(IPAddress ip, uint16_t port, const char *host, const char *CA_cert, const char *cert, const char *private_key); int peek(); size_t write(uint8_t data); size_t write(const uint8_t *buf, size_t size); diff --git a/libraries/WiFiClientSecure/src/ssl_client.cpp b/libraries/WiFiClientSecure/src/ssl_client.cpp index 299c47bea..a8b570a88 100644 --- a/libraries/WiFiClientSecure/src/ssl_client.cpp +++ b/libraries/WiFiClientSecure/src/ssl_client.cpp @@ -54,7 +54,7 @@ void ssl_init(sslclient_context *ssl_client) } -int start_ssl_client(sslclient_context *ssl_client, const char *host, uint32_t port, int timeout, const char *rootCABuff, bool useRootCABundle, const char *cli_cert, const char *cli_key, const char *pskIdent, const char *psKey, bool insecure, const char **alpn_protos) +int start_ssl_client(sslclient_context *ssl_client, const IPAddress& ip, uint32_t port, const char* hostname, int timeout, const char *rootCABuff, bool useRootCABundle, const char *cli_cert, const char *cli_key, const char *pskIdent, const char *psKey, bool insecure, const char **alpn_protos) { char buf[512]; int ret, flags; @@ -74,16 +74,11 @@ int start_ssl_client(sslclient_context *ssl_client, const char *host, uint32_t p return ssl_client->socket; } - IPAddress srv((uint32_t)0); - if(!WiFiGenericClass::hostByName(host, srv)){ - return -1; - } - fcntl( ssl_client->socket, F_SETFL, fcntl( ssl_client->socket, F_GETFL, 0 ) | O_NONBLOCK ); struct sockaddr_in serv_addr; memset(&serv_addr, 0, sizeof(serv_addr)); serv_addr.sin_family = AF_INET; - serv_addr.sin_addr.s_addr = srv; + serv_addr.sin_addr.s_addr = ip; serv_addr.sin_port = htons(port); if(timeout <= 0){ @@ -259,7 +254,7 @@ int start_ssl_client(sslclient_context *ssl_client, const char *host, uint32_t p log_v("Setting hostname for TLS session..."); // Hostname set here should match CN in server certificate - if((ret = mbedtls_ssl_set_hostname(&ssl_client->ssl_ctx, host)) != 0){ + if((ret = mbedtls_ssl_set_hostname(&ssl_client->ssl_ctx, hostname != NULL ? hostname : ip.toString().c_str())) != 0){ return handle_error(ret); } diff --git a/libraries/WiFiClientSecure/src/ssl_client.h b/libraries/WiFiClientSecure/src/ssl_client.h index 1f4179c98..9661b2224 100644 --- a/libraries/WiFiClientSecure/src/ssl_client.h +++ b/libraries/WiFiClientSecure/src/ssl_client.h @@ -30,7 +30,7 @@ typedef struct sslclient_context { void ssl_init(sslclient_context *ssl_client); -int start_ssl_client(sslclient_context *ssl_client, const char *host, uint32_t port, int timeout, const char *rootCABuff, bool useRootCABundle, const char *cli_cert, const char *cli_key, const char *pskIdent, const char *psKey, bool insecure, const char **alpn_protos); +int start_ssl_client(sslclient_context *ssl_client, const IPAddress& ip, uint32_t port, const char* hostname, int timeout, const char *rootCABuff, bool useRootCABundle, const char *cli_cert, const char *cli_key, const char *pskIdent, const char *psKey, bool insecure, const char **alpn_protos); void stop_ssl_socket(sslclient_context *ssl_client, const char *rootCABuff, const char *cli_cert, const char *cli_key); int data_to_read(sslclient_context *ssl_client); int send_ssl_data(sslclient_context *ssl_client, const uint8_t *data, size_t len); -- GitLab