提交 ca8c71ba 编写于 作者: M Matt Caswell

Add some tests for the new TLSv1.3 PSK code

Reviewed-by: NRich Salz <rsalz@openssl.org>
(Merged from https://github.com/openssl/openssl/pull/3670)
上级 011d768a
...@@ -1927,6 +1927,183 @@ static int test_ciphersuite_change(void) ...@@ -1927,6 +1927,183 @@ static int test_ciphersuite_change(void)
return testresult; return testresult;
} }
static SSL_SESSION *psk = NULL;
static const char *pskid = "Identity";
static const char *srvid;
static int use_session_cb_cnt = 0;
static int find_session_cb_cnt = 0;
static int use_session_cb(SSL *ssl, const EVP_MD *md, const unsigned char **id,
size_t *idlen, SSL_SESSION **sess)
{
use_session_cb_cnt++;
/* The first call should always have a NULL md */
if (use_session_cb_cnt == 1 && md != NULL)
return 0;
/* The second call should always have an md */
if (use_session_cb_cnt == 2 && md == NULL)
return 0;
/* We should only be called a maximum of twice */
if (use_session_cb_cnt == 3)
return 0;
if (psk != NULL)
SSL_SESSION_up_ref(psk);
*sess = psk;
*id = (const unsigned char *)pskid;
*idlen = strlen(pskid);
return 1;
}
static int find_session_cb(SSL *ssl, const unsigned char *identity,
size_t identity_len, SSL_SESSION **sess)
{
find_session_cb_cnt++;
/* We should only ever be called a maximum of twice per connection */
if (find_session_cb_cnt > 2)
return 0;
if (psk == NULL)
return 0;
/* Identity should match that set by the client */
if (strlen(srvid) != identity_len
|| strncmp(srvid, (const char *)identity, identity_len) != 0) {
/* No PSK found, continue but without a PSK */
*sess = NULL;
return 1;
}
SSL_SESSION_up_ref(psk);
*sess = psk;
return 1;
}
#define TLS13_AES_256_GCM_SHA384_BYTES ((const unsigned char *)"\x13\x02")
static int test_tls13_psk(void)
{
SSL_CTX *sctx = NULL, *cctx = NULL;
SSL *serverssl = NULL, *clientssl = NULL;
const SSL_CIPHER *cipher = NULL;
const unsigned char key[] = {
0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a, 0x0b,
0x0c, 0x0d, 0x0e, 0x0f, 0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17,
0x18, 0x19, 0x1a, 0x1b, 0x1c, 0x1d, 0x1e, 0x1f
};
int testresult = 0;
if (!TEST_true(create_ssl_ctx_pair(TLS_server_method(),
TLS_client_method(), &sctx,
&cctx, cert, privkey)))
goto end;
SSL_CTX_set_psk_use_session_callback(cctx, use_session_cb);
SSL_CTX_set_psk_find_session_callback(sctx, find_session_cb);
srvid = pskid;
/* Check we can create a connection if callback decides not to send a PSK */
if (!TEST_true(create_ssl_objects(sctx, cctx, &serverssl, &clientssl,
NULL, NULL))
|| !TEST_true(create_ssl_connection(serverssl, clientssl,
SSL_ERROR_NONE))
|| !TEST_false(SSL_session_reused(clientssl))
|| !TEST_false(SSL_session_reused(serverssl))
|| !TEST_true(use_session_cb_cnt == 1)
|| !TEST_true(find_session_cb_cnt == 0))
goto end;
shutdown_ssl_connection(serverssl, clientssl);
serverssl = clientssl = NULL;
use_session_cb_cnt = 0;
if (!TEST_true(create_ssl_objects(sctx, cctx, &serverssl, &clientssl,
NULL, NULL)))
goto end;
/* Create the PSK */
cipher = SSL_CIPHER_find(clientssl, TLS13_AES_256_GCM_SHA384_BYTES);
psk = SSL_SESSION_new();
if (!TEST_ptr(psk)
|| !TEST_ptr(cipher)
|| !TEST_true(SSL_SESSION_set1_master_key(psk, key, sizeof(key)))
|| !TEST_true(SSL_SESSION_set_cipher(psk, cipher))
|| !TEST_true(SSL_SESSION_set_protocol_version(psk,
TLS1_3_VERSION)))
goto end;
/* Check we can create a connection and the PSK is used */
if (!TEST_true(create_ssl_connection(serverssl, clientssl, SSL_ERROR_NONE))
|| !TEST_true(SSL_session_reused(clientssl))
|| !TEST_true(SSL_session_reused(serverssl))
|| !TEST_true(use_session_cb_cnt == 1)
|| !TEST_true(find_session_cb_cnt == 1))
goto end;
shutdown_ssl_connection(serverssl, clientssl);
serverssl = clientssl = NULL;
use_session_cb_cnt = find_session_cb_cnt = 0;
if (!TEST_true(create_ssl_objects(sctx, cctx, &serverssl, &clientssl,
NULL, NULL)))
goto end;
/* Force an HRR */
if (!TEST_true(SSL_set1_groups_list(serverssl, "P-256")))
goto end;
/*
* Check we can create a connection, the PSK is used and the callbacks are
* called twice.
*/
if (!TEST_true(create_ssl_connection(serverssl, clientssl, SSL_ERROR_NONE))
|| !TEST_true(SSL_session_reused(clientssl))
|| !TEST_true(SSL_session_reused(serverssl))
|| !TEST_true(use_session_cb_cnt == 2)
|| !TEST_true(find_session_cb_cnt == 2))
goto end;
shutdown_ssl_connection(serverssl, clientssl);
serverssl = clientssl = NULL;
use_session_cb_cnt = find_session_cb_cnt = 0;
/*
* Check that if the server rejects the PSK we can still connect, but with
* a full handshake
*/
srvid = "Dummy Identity";
if (!TEST_true(create_ssl_objects(sctx, cctx, &serverssl, &clientssl,
NULL, NULL))
|| !TEST_true(create_ssl_connection(serverssl, clientssl,
SSL_ERROR_NONE))
|| !TEST_false(SSL_session_reused(clientssl))
|| !TEST_false(SSL_session_reused(serverssl))
|| !TEST_true(use_session_cb_cnt == 1)
|| !TEST_true(find_session_cb_cnt == 1))
goto end;
shutdown_ssl_connection(serverssl, clientssl);
serverssl = clientssl = NULL;
testresult = 1;
end:
SSL_SESSION_free(psk);
SSL_free(serverssl);
SSL_free(clientssl);
SSL_CTX_free(sctx);
SSL_CTX_free(cctx);
return testresult;
}
#endif /* OPENSSL_NO_TLS1_3 */ #endif /* OPENSSL_NO_TLS1_3 */
static int clntaddoldcb = 0; static int clntaddoldcb = 0;
...@@ -2352,6 +2529,7 @@ int test_main(int argc, char *argv[]) ...@@ -2352,6 +2529,7 @@ int test_main(int argc, char *argv[])
#endif #endif
#ifndef OPENSSL_NO_TLS1_3 #ifndef OPENSSL_NO_TLS1_3
ADD_TEST(test_ciphersuite_change); ADD_TEST(test_ciphersuite_change);
ADD_TEST(test_tls13_psk);
ADD_ALL_TESTS(test_custom_exts, 5); ADD_ALL_TESTS(test_custom_exts, 5);
#else #else
ADD_ALL_TESTS(test_custom_exts, 3); ADD_ALL_TESTS(test_custom_exts, 3);
......
...@@ -661,3 +661,11 @@ int create_ssl_connection(SSL *serverssl, SSL *clientssl, int want) ...@@ -661,3 +661,11 @@ int create_ssl_connection(SSL *serverssl, SSL *clientssl, int want)
return 1; return 1;
} }
void shutdown_ssl_connection(SSL *serverssl, SSL *clientssl)
{
SSL_shutdown(clientssl);
SSL_shutdown(serverssl);
SSL_free(serverssl);
SSL_free(clientssl);
}
...@@ -18,6 +18,7 @@ int create_ssl_ctx_pair(const SSL_METHOD *sm, const SSL_METHOD *cm, ...@@ -18,6 +18,7 @@ int create_ssl_ctx_pair(const SSL_METHOD *sm, const SSL_METHOD *cm,
int create_ssl_objects(SSL_CTX *serverctx, SSL_CTX *clientctx, SSL **sssl, int create_ssl_objects(SSL_CTX *serverctx, SSL_CTX *clientctx, SSL **sssl,
SSL **cssl, BIO *s_to_c_fbio, BIO *c_to_s_fbio); SSL **cssl, BIO *s_to_c_fbio, BIO *c_to_s_fbio);
int create_ssl_connection(SSL *serverssl, SSL *clientssl, int want); int create_ssl_connection(SSL *serverssl, SSL *clientssl, int want);
void shutdown_ssl_connection(SSL *serverssl, SSL *clientssl);
/* Note: Not thread safe! */ /* Note: Not thread safe! */
const BIO_METHOD *bio_f_tls_dump_filter(void); const BIO_METHOD *bio_f_tls_dump_filter(void);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册