diff --git a/libraries/WiFiClientSecure/README.md b/libraries/WiFiClientSecure/README.md index c7582c15d997d718060e6f30437664b2f74ccc5a..fd5f4f51938b9d1a1880cd290160f04b0da5b13e 100644 --- a/libraries/WiFiClientSecure/README.md +++ b/libraries/WiFiClientSecure/README.md @@ -66,3 +66,18 @@ To use PSK: encryption for the connection Please see the WiFiClientPSK example. + +Specifying the ALPN Protocol +---------------------------- + +Application-Layer Protocol Negotiation (ALPN) is a Transport Layer Security (TLS) extension that allows +the application layer to negotiate which protocol should be performed over a secure connection in a manner +that avoids additional round trips and which is independent of the application-layer protocols. + +For example, this is used with AWS IoT Custom Authorizers where an MQTT client must set the ALPN protocol to ```mqtt```: + +``` +const char *aws_protos[] = {"mqtt", NULL}; +... +wiFiClient.setAlpnProtocols(aws_protos); +``` \ No newline at end of file diff --git a/libraries/WiFiClientSecure/keywords.txt b/libraries/WiFiClientSecure/keywords.txt index b1bf2c7388a67895e2f822e2a1f4a836705edfc4..4bab096dbd56631eaecf07ccee4088d5e2cd7a27 100644 --- a/libraries/WiFiClientSecure/keywords.txt +++ b/libraries/WiFiClientSecure/keywords.txt @@ -29,6 +29,7 @@ connected KEYWORD2 setCACert KEYWORD2 setCertificate KEYWORD2 setPrivateKey KEYWORD2 +setAlpnProtocols KEYWORD2 ####################################### # Constants (LITERAL1) diff --git a/libraries/WiFiClientSecure/src/WiFiClientSecure.cpp b/libraries/WiFiClientSecure/src/WiFiClientSecure.cpp index 26ac7e431cc6b3010279ee2e9bc2754a2f1c227b..4f19d57a4d0738cfb5107adebae399cc9bc515f9 100644 --- a/libraries/WiFiClientSecure/src/WiFiClientSecure.cpp +++ b/libraries/WiFiClientSecure/src/WiFiClientSecure.cpp @@ -43,6 +43,7 @@ WiFiClientSecure::WiFiClientSecure() _pskIdent = NULL; _psKey = NULL; next = NULL; + _alpn_protos = NULL; } @@ -66,6 +67,7 @@ WiFiClientSecure::WiFiClientSecure(int sock) _pskIdent = NULL; _psKey = NULL; next = NULL; + _alpn_protos = NULL; } WiFiClientSecure::~WiFiClientSecure() @@ -127,7 +129,7 @@ int WiFiClientSecure::connect(const char *host, uint16_t port, const char *CA_ce if(_timeout > 0){ sslclient->handshake_timeout = _timeout; } - int ret = start_ssl_client(sslclient, host, port, _timeout, CA_cert, cert, private_key, NULL, NULL, _use_insecure); + int ret = start_ssl_client(sslclient, host, port, _timeout, CA_cert, cert, private_key, NULL, NULL, _use_insecure, _alpn_protos); _lastError = ret; if (ret < 0) { log_e("start_ssl_client: %d", ret); @@ -147,7 +149,7 @@ int WiFiClientSecure::connect(const char *host, uint16_t port, const char *pskId if(_timeout > 0){ sslclient->handshake_timeout = _timeout; } - int ret = start_ssl_client(sslclient, host, port, _timeout, NULL, NULL, NULL, pskIdent, psKey, _use_insecure); + int ret = start_ssl_client(sslclient, host, port, _timeout, NULL, NULL, NULL, pskIdent, psKey, _use_insecure, _alpn_protos); _lastError = ret; if (ret < 0) { log_e("start_ssl_client: %d", ret); @@ -341,3 +343,8 @@ void WiFiClientSecure::setHandshakeTimeout(unsigned long handshake_timeout) { sslclient->handshake_timeout = handshake_timeout * 1000; } + +void WiFiClientSecure::setAlpnProtocols(const char **alpn_protos) +{ + _alpn_protos = alpn_protos; +} diff --git a/libraries/WiFiClientSecure/src/WiFiClientSecure.h b/libraries/WiFiClientSecure/src/WiFiClientSecure.h index f27df2fd29e76e7b30ca84f48ec59ff8b9ddadfd..bba94ceffbc957bd0f9b66506115b829b34bb521 100644 --- a/libraries/WiFiClientSecure/src/WiFiClientSecure.h +++ b/libraries/WiFiClientSecure/src/WiFiClientSecure.h @@ -39,6 +39,7 @@ protected: const char *_private_key; const char *_pskIdent; // identity for PSK cipher suites const char *_psKey; // key in hex for PSK cipher suites + const char **_alpn_protos; public: WiFiClientSecure *next; @@ -73,6 +74,7 @@ public: bool loadPrivateKey(Stream& stream, size_t size); bool verify(const char* fingerprint, const char* domain_name); void setHandshakeTimeout(unsigned long handshake_timeout); + void setAlpnProtocols(const char **alpn_protos); const mbedtls_x509_crt* getPeerCertificate() { return mbedtls_ssl_get_peer_cert(&sslclient->ssl_ctx); }; bool getFingerprintSHA256(uint8_t sha256_result[32]) { return get_peer_fingerprint(sslclient, sha256_result); }; int setTimeout(uint32_t seconds){ return 0; } diff --git a/libraries/WiFiClientSecure/src/ssl_client.cpp b/libraries/WiFiClientSecure/src/ssl_client.cpp index 31f839844e4fa9619ed42d78e8646b979cfa381c..c910206b3c9e4a18e226268293bd4ae0bf6708da 100644 --- a/libraries/WiFiClientSecure/src/ssl_client.cpp +++ b/libraries/WiFiClientSecure/src/ssl_client.cpp @@ -51,7 +51,7 @@ void ssl_init(sslclient_context *ssl_client) } -int start_ssl_client(sslclient_context *ssl_client, const char *host, uint32_t port, int timeout, const char *rootCABuff, const char *cli_cert, const char *cli_key, const char *pskIdent, const char *psKey, bool insecure) +int start_ssl_client(sslclient_context *ssl_client, const char *host, uint32_t port, int timeout, const char *rootCABuff, const char *cli_cert, const char *cli_key, const char *pskIdent, const char *psKey, bool insecure, const char **alpn_protos) { char buf[512]; int ret, flags; @@ -156,6 +156,13 @@ int start_ssl_client(sslclient_context *ssl_client, const char *host, uint32_t p return handle_error(ret); } + if (alpn_protos != NULL) { + log_v("Setting ALPN protocols"); + if ((ret = mbedtls_ssl_conf_alpn_protocols(&ssl_client->ssl_conf, alpn_protos) ) != 0) { + return handle_error(ret); + } + } + // MBEDTLS_SSL_VERIFY_REQUIRED if a CA certificate is defined on Arduino IDE and // MBEDTLS_SSL_VERIFY_NONE if not. diff --git a/libraries/WiFiClientSecure/src/ssl_client.h b/libraries/WiFiClientSecure/src/ssl_client.h index 8a4cc502a471a4d06ec22ca15a308d2ac14e106b..d6be76d18a12e3c741e0a5ae4fec30d04eae5ddd 100644 --- a/libraries/WiFiClientSecure/src/ssl_client.h +++ b/libraries/WiFiClientSecure/src/ssl_client.h @@ -29,7 +29,7 @@ typedef struct sslclient_context { void ssl_init(sslclient_context *ssl_client); -int start_ssl_client(sslclient_context *ssl_client, const char *host, uint32_t port, int timeout, const char *rootCABuff, const char *cli_cert, const char *cli_key, const char *pskIdent, const char *psKey, bool insecure); +int start_ssl_client(sslclient_context *ssl_client, const char *host, uint32_t port, int timeout, const char *rootCABuff, const char *cli_cert, const char *cli_key, const char *pskIdent, const char *psKey, bool insecure, const char **alpn_protos); void stop_ssl_socket(sslclient_context *ssl_client, const char *rootCABuff, const char *cli_cert, const char *cli_key); int data_to_read(sslclient_context *ssl_client); int send_ssl_data(sslclient_context *ssl_client, const uint8_t *data, size_t len);