diff --git a/test/sslapitest.c b/test/sslapitest.c index ce24fad51af4117f16ab767d7c93ae0c93bf3d4b..c183ca8108d9ab6967458dedbb81e766d7e7500e 100644 --- a/test/sslapitest.c +++ b/test/sslapitest.c @@ -1415,6 +1415,8 @@ static const char *srvid; static int use_session_cb_cnt = 0; static int find_session_cb_cnt = 0; +static int psk_client_cb_cnt = 0; +static int psk_server_cb_cnt = 0; static int use_session_cb(SSL *ssl, const EVP_MD *md, const unsigned char **id, size_t *idlen, SSL_SESSION **sess) @@ -1447,6 +1449,34 @@ static int use_session_cb(SSL *ssl, const EVP_MD *md, const unsigned char **id, return 1; } +static unsigned int psk_client_cb(SSL *ssl, const char *hint, char *id, + unsigned int max_id_len, + unsigned char *psk, + unsigned int max_psk_len) +{ + unsigned int psklen = 0; + + psk_client_cb_cnt++; + + if (strlen(pskid) + 1 > max_id_len) + return 0; + + /* We should only ever be called a maximum of twice per connection */ + if (psk_client_cb_cnt > 2) + return 0; + + if (clientpsk == NULL) + return 0; + + /* We'll reuse the PSK we set up for TLSv1.3 */ + if (SSL_SESSION_get_master_key(clientpsk, NULL, 0) > max_psk_len) + return 0; + psklen = SSL_SESSION_get_master_key(clientpsk, psk, max_psk_len); + strncpy(id, pskid, max_id_len); + + return psklen; +} + static int find_session_cb(SSL *ssl, const unsigned char *identity, size_t identity_len, SSL_SESSION **sess) { @@ -1473,6 +1503,33 @@ static int find_session_cb(SSL *ssl, const unsigned char *identity, return 1; } +static unsigned int psk_server_cb(SSL *ssl, const char *identity, + unsigned char *psk, unsigned int max_psk_len) +{ + unsigned int psklen = 0; + + psk_server_cb_cnt++; + + /* We should only ever be called a maximum of twice per connection */ + if (find_session_cb_cnt > 2) + return 0; + + if (serverpsk == NULL) + return 0; + + /* Identity should match that set by the client */ + if (strcmp(srvid, identity) != 0) { + return 0; + } + + /* We'll reuse the PSK we set up for TLSv1.3 */ + if (SSL_SESSION_get_master_key(serverpsk, NULL, 0) > max_psk_len) + return 0; + psklen = SSL_SESSION_get_master_key(serverpsk, psk, max_psk_len); + + return psklen; +} + #define MSG1 "Hello" #define MSG2 "World." #define MSG3 "This" @@ -1482,6 +1539,7 @@ static int find_session_cb(SSL *ssl, const unsigned char *identity, #define MSG7 "message." #define TLS13_AES_256_GCM_SHA384_BYTES ((const unsigned char *)"\x13\x02") +#define TLS13_AES_128_GCM_SHA256_BYTES ((const unsigned char *)"\x13\x01") /* * Helper method to setup objects for early data test. Caller frees objects on @@ -2440,7 +2498,7 @@ static int test_ciphersuite_change(void) return testresult; } -static int test_tls13_psk(void) +static int test_tls13_psk(int idx) { SSL_CTX *sctx = NULL, *cctx = NULL; SSL *serverssl = NULL, *clientssl = NULL; @@ -2458,11 +2516,31 @@ static int test_tls13_psk(void) &cctx, cert, privkey))) goto end; - SSL_CTX_set_psk_use_session_callback(cctx, use_session_cb); - SSL_CTX_set_psk_find_session_callback(sctx, find_session_cb); + /* + * We use a ciphersuite with SHA256 to ease testing old style PSK callbacks + * which will always default to SHA256 + */ + if (!TEST_true(SSL_CTX_set_cipher_list(cctx, "TLS13-AES-128-GCM-SHA256"))) + goto end; + + /* + * Test 0: New style callbacks only + * Test 1: New and old style callbacks (only the new ones should be used) + * Test 2: Old style callbacks only + */ + if (idx == 0 || idx == 1) { + SSL_CTX_set_psk_use_session_callback(cctx, use_session_cb); + SSL_CTX_set_psk_find_session_callback(sctx, find_session_cb); + } + if (idx == 1 || idx == 2) { + SSL_CTX_set_psk_client_callback(cctx, psk_client_cb); + SSL_CTX_set_psk_server_callback(sctx, psk_server_cb); + } srvid = pskid; use_session_cb_cnt = 0; find_session_cb_cnt = 0; + psk_client_cb_cnt = 0; + psk_server_cb_cnt = 0; /* Check we can create a connection if callback decides not to send a PSK */ if (!TEST_true(create_ssl_objects(sctx, cctx, &serverssl, &clientssl, @@ -2470,21 +2548,37 @@ static int test_tls13_psk(void) || !TEST_true(create_ssl_connection(serverssl, clientssl, SSL_ERROR_NONE)) || !TEST_false(SSL_session_reused(clientssl)) - || !TEST_false(SSL_session_reused(serverssl)) - || !TEST_true(use_session_cb_cnt == 1) - || !TEST_true(find_session_cb_cnt == 0)) + || !TEST_false(SSL_session_reused(serverssl))) goto end; + if (idx == 0 || idx == 1) { + if (!TEST_true(use_session_cb_cnt == 1) + || !TEST_true(find_session_cb_cnt == 0) + /* + * If no old style callback then below should be 0 + * otherwise 1 + */ + || !TEST_true(psk_client_cb_cnt == idx) + || !TEST_true(psk_server_cb_cnt == 0)) + goto end; + } else { + if (!TEST_true(use_session_cb_cnt == 0) + || !TEST_true(find_session_cb_cnt == 0) + || !TEST_true(psk_client_cb_cnt == 1) + || !TEST_true(psk_server_cb_cnt == 0)) + goto end; + } + shutdown_ssl_connection(serverssl, clientssl); serverssl = clientssl = NULL; - use_session_cb_cnt = 0; + use_session_cb_cnt = psk_client_cb_cnt = 0; if (!TEST_true(create_ssl_objects(sctx, cctx, &serverssl, &clientssl, NULL, NULL))) goto end; /* Create the PSK */ - cipher = SSL_CIPHER_find(clientssl, TLS13_AES_256_GCM_SHA384_BYTES); + cipher = SSL_CIPHER_find(clientssl, TLS13_AES_128_GCM_SHA256_BYTES); clientpsk = SSL_SESSION_new(); if (!TEST_ptr(clientpsk) || !TEST_ptr(cipher) @@ -2500,14 +2594,27 @@ static int test_tls13_psk(void) /* Check we can create a connection and the PSK is used */ if (!TEST_true(create_ssl_connection(serverssl, clientssl, SSL_ERROR_NONE)) || !TEST_true(SSL_session_reused(clientssl)) - || !TEST_true(SSL_session_reused(serverssl)) - || !TEST_true(use_session_cb_cnt == 1) - || !TEST_true(find_session_cb_cnt == 1)) + || !TEST_true(SSL_session_reused(serverssl))) goto end; + if (idx == 0 || idx == 1) { + if (!TEST_true(use_session_cb_cnt == 1) + || !TEST_true(find_session_cb_cnt == 1) + || !TEST_true(psk_client_cb_cnt == 0) + || !TEST_true(psk_server_cb_cnt == 0)) + goto end; + } else { + if (!TEST_true(use_session_cb_cnt == 0) + || !TEST_true(find_session_cb_cnt == 0) + || !TEST_true(psk_client_cb_cnt == 1) + || !TEST_true(psk_server_cb_cnt == 1)) + goto end; + } + shutdown_ssl_connection(serverssl, clientssl); serverssl = clientssl = NULL; use_session_cb_cnt = find_session_cb_cnt = 0; + psk_client_cb_cnt = psk_server_cb_cnt = 0; if (!TEST_true(create_ssl_objects(sctx, cctx, &serverssl, &clientssl, NULL, NULL))) @@ -2523,14 +2630,27 @@ static int test_tls13_psk(void) */ if (!TEST_true(create_ssl_connection(serverssl, clientssl, SSL_ERROR_NONE)) || !TEST_true(SSL_session_reused(clientssl)) - || !TEST_true(SSL_session_reused(serverssl)) - || !TEST_true(use_session_cb_cnt == 2) - || !TEST_true(find_session_cb_cnt == 2)) + || !TEST_true(SSL_session_reused(serverssl))) goto end; + if (idx == 0 || idx == 1) { + if (!TEST_true(use_session_cb_cnt == 2) + || !TEST_true(find_session_cb_cnt == 2) + || !TEST_true(psk_client_cb_cnt == 0) + || !TEST_true(psk_server_cb_cnt == 0)) + goto end; + } else { + if (!TEST_true(use_session_cb_cnt == 0) + || !TEST_true(find_session_cb_cnt == 0) + || !TEST_true(psk_client_cb_cnt == 2) + || !TEST_true(psk_server_cb_cnt == 2)) + goto end; + } + shutdown_ssl_connection(serverssl, clientssl); serverssl = clientssl = NULL; use_session_cb_cnt = find_session_cb_cnt = 0; + psk_client_cb_cnt = psk_server_cb_cnt = 0; /* * Check that if the server rejects the PSK we can still connect, but with @@ -2542,11 +2662,27 @@ static int test_tls13_psk(void) || !TEST_true(create_ssl_connection(serverssl, clientssl, SSL_ERROR_NONE)) || !TEST_false(SSL_session_reused(clientssl)) - || !TEST_false(SSL_session_reused(serverssl)) - || !TEST_true(use_session_cb_cnt == 1) - || !TEST_true(find_session_cb_cnt == 1)) + || !TEST_false(SSL_session_reused(serverssl))) goto end; + if (idx == 0 || idx == 1) { + if (!TEST_true(use_session_cb_cnt == 1) + || !TEST_true(find_session_cb_cnt == 1) + || !TEST_true(psk_client_cb_cnt == 0) + /* + * If no old style callback then below should be 0 + * otherwise 1 + */ + || !TEST_true(psk_server_cb_cnt == idx)) + goto end; + } else { + if (!TEST_true(use_session_cb_cnt == 0) + || !TEST_true(find_session_cb_cnt == 0) + || !TEST_true(psk_client_cb_cnt == 1) + || !TEST_true(psk_server_cb_cnt == 1)) + goto end; + } + shutdown_ssl_connection(serverssl, clientssl); serverssl = clientssl = NULL; testresult = 1; @@ -3506,7 +3642,7 @@ int setup_tests(void) #endif #ifndef OPENSSL_NO_TLS1_3 ADD_TEST(test_ciphersuite_change); - ADD_TEST(test_tls13_psk); + ADD_ALL_TESTS(test_tls13_psk, 3); ADD_ALL_TESTS(test_custom_exts, 5); ADD_TEST(test_stateless); ADD_TEST(test_pha_key_update);