From beaff7c09d4ae925f3d0269dc42a929468fd3745 Mon Sep 17 00:00:00 2001 From: chenyong <1521761801@qq.com> Date: Thu, 15 Nov 2018 14:39:31 +0800 Subject: [PATCH] [net][sal] Add SAL components TLS features support Signed-off-by: chenyong <1521761801@qq.com> --- components/net/Kconfig | 5 + components/net/sal_socket/SConscript | 5 +- components/net/sal_socket/impl/af_inet_at.c | 23 +- components/net/sal_socket/impl/af_inet_lwip.c | 24 +- .../net/sal_socket/impl/proto_mbedtls.c | 239 ++++++++++++++++ components/net/sal_socket/include/sal.h | 45 +-- components/net/sal_socket/include/sal_tls.h | 68 +++++ .../include/socket/sys_socket/sys/socket.h | 3 + components/net/sal_socket/src/sal_socket.c | 270 +++++++++++++++--- 9 files changed, 599 insertions(+), 83 deletions(-) create mode 100644 components/net/sal_socket/impl/proto_mbedtls.c create mode 100644 components/net/sal_socket/include/sal_tls.h diff --git a/components/net/Kconfig b/components/net/Kconfig index c302dec152..cef3ceacd3 100644 --- a/components/net/Kconfig +++ b/components/net/Kconfig @@ -21,6 +21,11 @@ config RT_USING_SAL bool "Support AT Commands stack" default y depends on AT_USING_SOCKET + + config SAL_USING_TLS + bool "Support MbedTLS protocol" + default y + depends on PKG_USING_MBEDTLS endmenu endif diff --git a/components/net/sal_socket/SConscript b/components/net/sal_socket/SConscript index e962193359..cc7634f7d9 100644 --- a/components/net/sal_socket/SConscript +++ b/components/net/sal_socket/SConscript @@ -15,10 +15,13 @@ if GetDepend('SAL_USING_LWIP'): if GetDepend('SAL_USING_AT'): src += Glob('impl/af_inet_at.c') - + if GetDepend('SAL_USING_LWIP') or GetDepend('SAL_USING_AT'): CPPPATH += [cwd + '/impl'] +if GetDepend('SAL_USING_TLS'): + src += Glob('impl/proto_mbedtls.c') + if GetDepend('SAL_USING_POSIX'): CPPPATH += [cwd + '/include/dfs_net'] src += Glob('socket/net_sockets.c') diff --git a/components/net/sal_socket/impl/af_inet_at.c b/components/net/sal_socket/impl/af_inet_at.c index c064283a61..90d1537967 100644 --- a/components/net/sal_socket/impl/af_inet_at.c +++ b/components/net/sal_socket/impl/af_inet_at.c @@ -62,7 +62,7 @@ static int at_poll(struct dfs_fd *file, struct rt_pollreq *req) } #endif -static const struct proto_ops at_inet_stream_ops = +static const struct sal_socket_ops at_socket_ops = { at_socket, at_closesocket, @@ -90,25 +90,30 @@ static int at_create(struct sal_socket *socket, int type, int protocol) //TODO Check type & protocol - socket->ops = &at_inet_stream_ops; + socket->ops = &at_socket_ops; return 0; } -static const struct proto_family at_inet_family_ops = { - "at", - AF_AT, - AF_INET, - at_create, +static struct sal_proto_ops at_proto_ops = +{ at_gethostbyname, NULL, - at_freeaddrinfo, at_getaddrinfo, + at_freeaddrinfo, +}; + +static const struct sal_proto_family at_inet_family = +{ + AF_AT, + AF_INET, + at_create, + &at_proto_ops, }; int at_inet_init(void) { - sal_proto_family_register(&at_inet_family_ops); + sal_proto_family_register(&at_inet_family); return 0; } diff --git a/components/net/sal_socket/impl/af_inet_lwip.c b/components/net/sal_socket/impl/af_inet_lwip.c index 9e726c87e9..0826b052ab 100644 --- a/components/net/sal_socket/impl/af_inet_lwip.c +++ b/components/net/sal_socket/impl/af_inet_lwip.c @@ -259,7 +259,8 @@ static int inet_poll(struct dfs_fd *file, struct rt_pollreq *req) } #endif -static const struct proto_ops lwip_inet_stream_ops = { +static const struct sal_socket_ops lwip_socket_ops = +{ inet_socket, lwip_close, lwip_bind, @@ -286,25 +287,30 @@ static int inet_create(struct sal_socket *socket, int type, int protocol) //TODO Check type & protocol - socket->ops = &lwip_inet_stream_ops; + socket->ops = &lwip_socket_ops; return 0; } -static const struct proto_family lwip_inet_family_ops = { - "lwip", - AF_INET, - AF_INET, - inet_create, +static struct sal_proto_ops lwip_proto_ops = +{ lwip_gethostbyname, lwip_gethostbyname_r, - lwip_freeaddrinfo, lwip_getaddrinfo, + lwip_freeaddrinfo, +}; + +static const struct sal_proto_family lwip_inet_family = +{ + AF_INET, + AF_INET, + inet_create, + &lwip_proto_ops, }; int lwip_inet_init(void) { - sal_proto_family_register(&lwip_inet_family_ops); + sal_proto_family_register(&lwip_inet_family); return 0; } diff --git a/components/net/sal_socket/impl/proto_mbedtls.c b/components/net/sal_socket/impl/proto_mbedtls.c new file mode 100644 index 0000000000..8ed9f67ff1 --- /dev/null +++ b/components/net/sal_socket/impl/proto_mbedtls.c @@ -0,0 +1,239 @@ +/* + * Copyright (c) 2006-2018, RT-Thread Development Team + * + * SPDX-License-Identifier: Apache-2.0 + * + * Change Logs: + * Date Author Notes + * 2018-11-12 ChenYong First version + */ + +#include + +#ifdef RT_USING_DFS +#include +#endif + +#ifdef SAL_USING_TLS +#include +#endif +#include +#include + +#ifdef SAL_USING_TLS + +#if !defined(MBEDTLS_CONFIG_FILE) +#include +#else +#include MBEDTLS_CONFIG_FILE +#endif + +#include +#include + +#ifndef SAL_MEBDTLS_BUFFER_LEN +#define SAL_MEBDTLS_BUFFER_LEN 1024 +#endif + +static void *mebdtls_socket(int socket) +{ + MbedTLSSession *session = RT_NULL; + char *pers = "mbedtls"; + + if (socket < 0) + { + return RT_NULL; + } + + session = (MbedTLSSession *) tls_calloc(1, sizeof(MbedTLSSession)); + if (session == RT_NULL) + { + return RT_NULL; + } + + session->buffer_len = SAL_MEBDTLS_BUFFER_LEN; + session->buffer = tls_calloc(1, session->buffer_len); + if (session->buffer == RT_NULL) + { + tls_free(session); + session = RT_NULL; + + return RT_NULL; + } + + /* initialize TLS Client sesison */ + if (mbedtls_client_init(session, (void *) pers, rt_strlen(pers)) != RT_EOK) + { + mbedtls_client_close(session); + return RT_NULL; + } + session->server_fd.fd = socket; + + return (void *)session; +} + +int mbedtls_net_send_cb(void *ctx, const unsigned char *buf, size_t len) +{ + struct sal_socket *sock; + int socket, ret; + + RT_ASSERT(ctx); + RT_ASSERT(buf); + + socket = ((mbedtls_net_context *) ctx)->fd; + sock = sal_get_socket(socket); + if (sock == RT_NULL) + { + return -1; + } + + /* Register scoket sendto option to TLS send data callback */ + ret = sock->ops->sendto((int) sock->user_data, (void *)buf, len, 0, RT_NULL, RT_NULL); + if (ret < 0) + { +#ifdef RT_USING_DFS + if ((fcntl(socket, F_GETFL) & O_NONBLOCK) == O_NONBLOCK) + return MBEDTLS_ERR_SSL_WANT_WRITE; +#endif + if (errno == ECONNRESET) + return MBEDTLS_ERR_NET_CONN_RESET; + if ( errno == EINTR) + return MBEDTLS_ERR_SSL_WANT_READ; + + return MBEDTLS_ERR_NET_SEND_FAILED ; + } + + return ret; +} + +int mbedtls_net_recv_cb( void *ctx, unsigned char *buf, size_t len) +{ + struct sal_socket *sock; + int socket, ret; + + RT_ASSERT(ctx); + RT_ASSERT(buf); + + socket = ((mbedtls_net_context *) ctx)->fd; + sock = sal_get_socket(socket); + if (sock == RT_NULL) + { + return -1; + } + + /* Register scoket recvfrom option to TLS recv data callback */ + ret = sock->ops->recvfrom((int) sock->user_data, (void *)buf, len, 0, RT_NULL, RT_NULL); + if (ret < 0) + { +#ifdef RT_USING_DFS + if ((fcntl(socket, F_GETFL) & O_NONBLOCK) == O_NONBLOCK) + return MBEDTLS_ERR_SSL_WANT_WRITE; +#endif + if (errno == ECONNRESET) + return MBEDTLS_ERR_NET_CONN_RESET; + if ( errno == EINTR) + return MBEDTLS_ERR_SSL_WANT_READ; + + return MBEDTLS_ERR_NET_RECV_FAILED ; + } + + return ret; +} + +static int mbedtls_connect(void *sock) +{ + MbedTLSSession *session = RT_NULL; + int ret = 0; + + RT_ASSERT(sock); + + session = (MbedTLSSession *) sock; + + /* Set the SSL Configure infromation */ + ret = mbedtls_client_context(session); + if (ret < 0) + { + goto __exit; + } + + /* Set the underlying BIO callbacks for write, read and read-with-timeout. */ + mbedtls_ssl_set_bio(&session->ssl, &session->server_fd, mbedtls_net_send_cb, mbedtls_net_recv_cb, RT_NULL); + + while ((ret = mbedtls_ssl_handshake(&session->ssl)) != 0) + { + if (ret != MBEDTLS_ERR_SSL_WANT_READ && ret != MBEDTLS_ERR_SSL_WANT_WRITE) + { + goto __exit; + } + } + + /* Return the result of the certificate verification */ + ret = mbedtls_ssl_get_verify_result(&session->ssl); + if (ret != 0) + { + rt_memset(session->buffer, 0x00, session->buffer_len); + mbedtls_x509_crt_verify_info((char *)session->buffer, session->buffer_len, " ! ", ret); + goto __exit; + } + + return ret; + +__exit: + if (session) + { + mbedtls_client_close(session); + } + + return ret; +} + +static int mbedtls_closesocket(void *sock) +{ + struct sal_socket *ssock; + int socket; + + if (sock == RT_NULL) + { + return 0; + } + + socket = ((MbedTLSSession *) sock)->server_fd.fd; + ssock = sal_get_socket(socket); + if (ssock == RT_NULL) + { + return -1; + } + + /* Close TLS client session, and clean user-data in SAL socket */ + mbedtls_client_close((MbedTLSSession *) sock); + ssock->user_data_tls = RT_NULL; + + return 0; +} + +static const struct sal_proto_tls_ops mbedtls_proto_ops= +{ + RT_NULL, + mebdtls_socket, + mbedtls_connect, + (int (*)(void *sock, const void *data, size_t size)) mbedtls_client_write, + (int (*)(void *sock, void *mem, size_t len)) mbedtls_client_read, + mbedtls_closesocket, +}; + +static const struct sal_proto_tls mbedtls_proto = +{ + "mbedtls", + &mbedtls_proto_ops, +}; + +int sal_mbedtls_proto_init(void) +{ + /* register MbedTLS protocol options to SAL */ + sal_proto_tls_register(&mbedtls_proto); + + return 0; +} +INIT_COMPONENT_EXPORT(sal_mbedtls_proto_init); + +#endif /* SAL_USING_TLS */ diff --git a/components/net/sal_socket/include/sal.h b/components/net/sal_socket/include/sal.h index 68248742d6..17d44eef77 100644 --- a/components/net/sal_socket/include/sal.h +++ b/components/net/sal_socket/include/sal.h @@ -25,7 +25,7 @@ extern "C" { typedef uint32_t socklen_t; #endif -/* sal socket magic word */ +/* SAL socket magic word */ #define SAL_SOCKET_MAGIC 0x5A10 /* The maximum number of sockets structure */ @@ -38,12 +38,12 @@ typedef uint32_t socklen_t; #define SAL_PROTO_FAMILIES_NUM 4 #endif -/* sal socket offset */ +/* SAL socket offset */ #ifndef SAL_SOCKET_OFFSET #define SAL_SOCKET_OFFSET 0 #endif -struct proto_ops +struct sal_socket_ops { int (*socket) (int domain, int type, int protocol); int (*closesocket)(int s); @@ -64,30 +64,38 @@ struct proto_ops #endif }; +struct sal_proto_ops +{ + struct hostent* (*gethostbyname) (const char *name); + int (*gethostbyname_r)(const char *name, struct hostent *ret, char *buf, size_t buflen, struct hostent **result, int *h_errnop); + int (*getaddrinfo) (const char *nodename, const char *servname, const struct addrinfo *hints, struct addrinfo **res); + void (*freeaddrinfo) (struct addrinfo *ai); +}; + struct sal_socket { - uint32_t magic; /* sal socket magic word */ + uint32_t magic; /* SAL socket magic word */ - int socket; /* sal socket descriptor */ + int socket; /* SAL socket descriptor */ int domain; int type; int protocol; - const struct proto_ops *ops; /* socket options */ - void *user_data; /* specific sal socket data */ + const struct sal_socket_ops *ops; /* socket options */ + + void *user_data; /* user-specific data */ +#ifdef SAL_USING_TLS + void *user_data_tls; /* user-specific TLS data */ +#endif }; -struct proto_family +struct sal_proto_family { - char name[RT_NAME_MAX]; int family; /* primary protocol families type */ int sec_family; /* secondary protocol families type */ - int (*create)(struct sal_socket *sal_socket, int type, int protocol); /* register socket options */ + int (*create)(struct sal_socket *sal_socket, int type, int protocol); /* register socket options */ - struct hostent* (*gethostbyname) (const char *name); - int (*gethostbyname_r)(const char *name, struct hostent *ret, char *buf, size_t buflen, struct hostent **result, int *h_errnop); - void (*freeaddrinfo) (struct addrinfo *ai); - int (*getaddrinfo) (const char *nodename, const char *servname, const struct addrinfo *hints, struct addrinfo **res); + struct sal_proto_ops *ops; /* protocol family options */ }; /* SAL(Socket Abstraction Layer) initialize */ @@ -95,10 +103,11 @@ int sal_init(void); struct sal_socket *sal_get_socket(int sock); -/* protocol family register and unregister operate */ -int sal_proto_family_register(const struct proto_family *pf); -int sal_proto_family_unregister(const struct proto_family *pf); -struct proto_family *sal_proto_family_find(const char *name); +/* SAL protocol family register and unregister operate */ +int sal_proto_family_register(const struct sal_proto_family *pf); +int sal_proto_family_unregister(int family); +rt_bool_t sal_proto_family_is_registered(int family); +struct sal_proto_family *sal_proto_family_find(int family); #ifdef __cplusplus } diff --git a/components/net/sal_socket/include/sal_tls.h b/components/net/sal_socket/include/sal_tls.h new file mode 100644 index 0000000000..7681b313bd --- /dev/null +++ b/components/net/sal_socket/include/sal_tls.h @@ -0,0 +1,68 @@ +/* + * Copyright (c) 2006-2018, RT-Thread Development Team + * + * SPDX-License-Identifier: Apache-2.0 + * + * Change Logs: + * Date Author Notes + * 2018-11-10 ChenYong First version + */ +#ifndef __SAL_TLS_H__ +#define __SAL_TLS_H__ + +#ifdef __cplusplus +extern "C" { +#endif + +#include + +/* Protocol level for TLS. + * Here, the same socket protocol level for TLS as in Linux was used. + */ +#define SOL_TLS 282 + +/* Socket options for TLS */ + +/* Socket option to select TLS credentials to use. */ +#define TLS_CRET_LIST 1 +/* Socket option to set select ciphersuites to use. */ +#define TLS_CIPHERSUITE_LIST 2 +/* Socket option to set peer verification level for TLS connection. */ +#define TLS_PEER_VERIFY 3 +/* Socket option to set role for DTLS connection. */ +#define TLS_DTLS_ROLE 4 + +/* Protocol numbers for TLS protocols */ +#define PROTOCOL_TLS 256 +#define PROTOCOL_DTLS 257 + + +struct sal_proto_tls_ops +{ + int (*init)(void); + void* (*socket)(int socket); + int (*connect)(void *sock); + int (*send)(void *sock, const void *data, size_t size); + int (*recv)(void *sock, void *mem, size_t len); + int (*closesocket)(void *sock); + + int (*set_cret_list)(void *sock, const void *cert, size_t size); /* Set TLS credentials */ + int (*set_ciphersurite)(void *sock, const void* ciphersurite, size_t size); /* Set select ciphersuites */ + int (*set_peer_verify)(void *sock, const void* peer_verify, size_t size); /* Set peer verification */ + int (*set_dtls_role)(void *sock, const void *dtls_role, size_t size); /* Set role for DTLS */ +}; + +struct sal_proto_tls +{ + char name[RT_NAME_MAX]; /* TLS protocol name */ + const struct sal_proto_tls_ops *ops; /* SAL TLS protocol options */ +}; + +/* SAL TLS protocol register */ +int sal_proto_tls_register(const struct sal_proto_tls *pt); + +#ifdef __cplusplus +} +#endif + +#endif /* __SAL_TLS_H__ */ diff --git a/components/net/sal_socket/include/socket/sys_socket/sys/socket.h b/components/net/sal_socket/include/socket/sys_socket/sys/socket.h index dd2c3850a6..04470aa84a 100644 --- a/components/net/sal_socket/include/socket/sys_socket/sys/socket.h +++ b/components/net/sal_socket/include/socket/sys_socket/sys/socket.h @@ -14,6 +14,9 @@ #include #include +#ifdef SAL_USING_TLS +#include +#endif #ifdef __cplusplus extern "C" { diff --git a/components/net/sal_socket/src/sal_socket.c b/components/net/sal_socket/src/sal_socket.c index ac36cb2139..dba8e66254 100644 --- a/components/net/sal_socket/src/sal_socket.c +++ b/components/net/sal_socket/src/sal_socket.c @@ -6,6 +6,7 @@ * Change Logs: * Date Author Notes * 2018-05-23 ChenYong First version + * 2018-11-12 ChenYong Add TLS support */ #include @@ -13,6 +14,9 @@ #include #include +#ifdef SAL_USING_TLS +#include +#endif #include #define DBG_ENABLE @@ -30,13 +34,32 @@ struct sal_socket_table struct sal_socket **sockets; }; +#ifdef SAL_USING_TLS +/* The global TLS protocol options */ +static struct sal_proto_tls *proto_tls; +#endif + /* The global array of available protocol families */ -static struct proto_family proto_families[SAL_PROTO_FAMILIES_NUM]; +static struct sal_proto_family proto_families[SAL_PROTO_FAMILIES_NUM]; /* The global socket table */ static struct sal_socket_table socket_table; static struct rt_mutex sal_core_lock; static rt_bool_t init_ok = RT_FALSE; +#define IS_SOCKET_PROTO_TLS(sock) (((sock)->protocol == PROTOCOL_TLS) || \ + ((sock)->protocol == PROTOCOL_DTLS)) +#define SAL_SOCKOPS_PROTO_TLS_VALID(sock, name) (proto_tls && (proto_tls->ops->name) && IS_SOCKET_PROTO_TLS(sock)) + +#define SAL_SOCKOPT_PROTO_TLS_EXEC(sock, name, optval, optlen) \ +do \ +{ \ + if (SAL_SOCKOPS_PROTO_TLS_VALID(sock, name)) \ + { \ + return proto_tls->ops->name((sock)->user_data_tls, (optval), (optlen)); \ + } \ +}while(0) \ + + /** * SAL (Socket Abstraction Layer) initialize. * @@ -47,7 +70,7 @@ int sal_init(void) { int cn; - if(init_ok) + if (init_ok) { LOG_D("Socket Abstraction Layer is already initialized."); return 0; @@ -73,15 +96,32 @@ int sal_init(void) } INIT_COMPONENT_EXPORT(sal_init); +/** + * This function will register TLS protocol to the global TLS protocol. + * + * @param pt TLS protocol object + * + * @return 0: TLS protocol object register success + */ +#ifdef SAL_USING_TLS +int sal_proto_tls_register(const struct sal_proto_tls *pt) +{ + RT_ASSERT(pt); + proto_tls = (struct sal_proto_tls *) pt; + + return 0; +} +#endif + /** * This function will register protocol family to the global array of protocol families. * * @param pf protocol family object * - * @return 0 : protocol family object register success - * -1 : the global array of available protocol families is full + * @return 0: protocol family object register success + * -1: the global array of available protocol families is full */ -int sal_proto_family_register(const struct proto_family *pf) +int sal_proto_family_register(const struct sal_proto_family *pf) { rt_base_t level; int idx; @@ -92,11 +132,11 @@ int sal_proto_family_register(const struct proto_family *pf) /* check protocol family is already registered */ for(idx = 0; idx < SAL_PROTO_FAMILIES_NUM; idx++) { - if(rt_strcmp(proto_families[idx].name, pf->name) == 0) + if (proto_families[idx].family == pf->family && proto_families[idx].create) { /* enable interrupt */ rt_hw_interrupt_enable(level); - LOG_E("%s protocol family is already registered!", pf->name); + LOG_E("%s protocol family is already registered!", pf->family); return -1; } } @@ -105,22 +145,22 @@ int sal_proto_family_register(const struct proto_family *pf) for(idx = 0; idx < SAL_PROTO_FAMILIES_NUM && proto_families[idx].create; idx++); /* can't find an empty protocol family entry */ - if(idx == SAL_PROTO_FAMILIES_NUM) + if (idx == SAL_PROTO_FAMILIES_NUM) { /* enable interrupt */ rt_hw_interrupt_enable(level); return -1; } - rt_strncpy(proto_families[idx].name, pf->name, rt_strlen(pf->name)); proto_families[idx].family = pf->family; proto_families[idx].sec_family = pf->sec_family; proto_families[idx].create = pf->create; - proto_families[idx].gethostbyname = pf->gethostbyname; - proto_families[idx].gethostbyname_r = pf->gethostbyname_r; - proto_families[idx].freeaddrinfo = pf->freeaddrinfo; - proto_families[idx].getaddrinfo = pf->getaddrinfo; + proto_families[idx].ops = pf->ops; + proto_families[idx].ops->gethostbyname = pf->ops->gethostbyname; + proto_families[idx].ops->gethostbyname_r = pf->ops->gethostbyname_r; + proto_families[idx].ops->freeaddrinfo = pf->ops->freeaddrinfo; + proto_families[idx].ops->getaddrinfo = pf->ops->getaddrinfo; /* enable interrupt */ rt_hw_interrupt_enable(level); @@ -136,17 +176,17 @@ int sal_proto_family_register(const struct proto_family *pf) * @return >=0 : unregister protocol family index * -1 : unregister failed */ -int sal_proto_family_unregister(const struct proto_family *pf) +int sal_proto_family_unregister(int family) { int idx = 0; - RT_ASSERT(pf != RT_NULL); + RT_ASSERT(family > 0 && family < AF_MAX); for(idx = 0; idx < SAL_PROTO_FAMILIES_NUM; idx++) { - if(rt_strcmp(proto_families[idx].name, pf->name) == 0) + if (proto_families[idx].family == family && proto_families[idx].create) { - rt_memset(&proto_families[idx], 0x00, sizeof(struct proto_family)); + rt_memset(&proto_families[idx], 0x00, sizeof(struct sal_proto_family)); return idx; } @@ -156,21 +196,46 @@ int sal_proto_family_unregister(const struct proto_family *pf) } /** - * This function will get protocol family by name. + * This function will judge whether protocol family is registered * - * @param name protocol family name + * @param family protocol family number + * + * @return 1: protocol family is registered + * 0: protocol family is not registered + */ +rt_bool_t sal_proto_family_is_registered(int family) +{ + int idx = 0; + + RT_ASSERT(family > 0 && family < AF_MAX); + + for (idx = 0; idx < SAL_PROTO_FAMILIES_NUM; idx++) + { + if (proto_families[idx].family == family && proto_families[idx].create) + { + return RT_TRUE; + } + } + + return RT_FALSE; +} + +/** + * This function will get protocol family object by family number. + * + * @param family protocol family number * * @return protocol family object */ -struct proto_family *sal_proto_family_find(const char *name) +struct sal_proto_family *sal_proto_family_find(int family) { int idx = 0; - RT_ASSERT(name != RT_NULL); + RT_ASSERT(family > 0 && family < AF_MAX); for (idx = 0; idx < SAL_PROTO_FAMILIES_NUM; idx++) { - if (rt_strcmp(proto_families[idx].name, name) == 0) + if (proto_families[idx].family == family && proto_families[idx].create) { return &proto_families[idx]; } @@ -238,7 +303,7 @@ static void sal_unlock(void) * * @return protocol family structure address */ -static struct proto_family *get_proto_family(int family) +static struct sal_proto_family *get_proto_family(int family) { int idx; @@ -278,7 +343,7 @@ static struct proto_family *get_proto_family(int family) static int socket_init(int family, int type, int protocol, struct sal_socket **res) { struct sal_socket *sock; - struct proto_family *pf; + struct sal_proto_family *pf; if (family < 0 || family > AF_MAX) { @@ -383,6 +448,11 @@ static int socket_new(void) sock = st->sockets[idx]; sock->socket = idx + SAL_SOCKET_OFFSET; sock->magic = SAL_SOCKET_MAGIC; + sock->ops = RT_NULL; + sock->user_data = RT_NULL; +#ifdef SAL_USING_TLS + sock->user_data_tls = RT_NULL; +#endif __result: sal_unlock(); @@ -474,6 +544,15 @@ int sal_shutdown(int socket, int how) if (sock->ops->shutdown((int) sock->user_data, how) == 0) { +#ifdef SAL_USING_TLS + if (SAL_SOCKOPS_PROTO_TLS_VALID(sock, closesocket)) + { + if (proto_tls->ops->closesocket(sock->user_data_tls) < 0) + { + return -1; + } + } +#endif rt_free(sock); socket_table.sockets[socket] = RT_NULL; return 0; @@ -551,12 +630,46 @@ int sal_setsockopt(int socket, int level, int optname, const void *optval, sockl return -RT_ENOSYS; } +#ifdef SAL_USING_TLS + if (level == SOL_TLS) + { + switch (optname) + { + case TLS_CRET_LIST: + SAL_SOCKOPT_PROTO_TLS_EXEC(sock, set_cret_list, optval, optlen); + break; + + case TLS_CIPHERSUITE_LIST: + SAL_SOCKOPT_PROTO_TLS_EXEC(sock, set_ciphersurite, optval, optlen); + break; + + case TLS_PEER_VERIFY: + SAL_SOCKOPT_PROTO_TLS_EXEC(sock, set_peer_verify, optval, optlen); + break; + + case TLS_DTLS_ROLE: + SAL_SOCKOPT_PROTO_TLS_EXEC(sock, set_dtls_role, optval, optlen); + break; + + default: + return -1; + } + + return 0; + } + else + { + return sock->ops->setsockopt((int) sock->user_data, level, optname, optval, optlen); + } +#else return sock->ops->setsockopt((int) sock->user_data, level, optname, optval, optlen); +#endif /* SAL_USING_TLS */ } int sal_connect(int socket, const struct sockaddr *name, socklen_t namelen) { struct sal_socket *sock; + int ret; sock = sal_get_socket(socket); if (!sock) @@ -569,7 +682,20 @@ int sal_connect(int socket, const struct sockaddr *name, socklen_t namelen) return -RT_ENOSYS; } - return sock->ops->connect((int) sock->user_data, name, namelen); + ret = sock->ops->connect((int) sock->user_data, name, namelen); +#ifdef SAL_USING_TLS + if (ret >= 0 && SAL_SOCKOPS_PROTO_TLS_VALID(sock, connect)) + { + if (proto_tls->ops->connect(sock->user_data_tls) < 0) + { + return -1; + } + + return ret; + } +#endif + + return ret; } int sal_listen(int socket, int backlog) @@ -606,7 +732,24 @@ int sal_recvfrom(int socket, void *mem, size_t len, int flags, return -RT_ENOSYS; } +#ifdef SAL_USING_TLS + if (SAL_SOCKOPS_PROTO_TLS_VALID(sock, recv)) + { + int ret; + + if ((ret = proto_tls->ops->recv(sock->user_data_tls, mem, len)) < 0) + { + return -1; + } + return ret; + } + else + { + return sock->ops->recvfrom((int) sock->user_data, mem, len, flags, from, fromlen); + } +#else return sock->ops->recvfrom((int) sock->user_data, mem, len, flags, from, fromlen); +#endif } int sal_sendto(int socket, const void *dataptr, size_t size, int flags, @@ -625,7 +768,24 @@ int sal_sendto(int socket, const void *dataptr, size_t size, int flags, return -RT_ENOSYS; } +#ifdef SAL_USING_TLS + if (SAL_SOCKOPS_PROTO_TLS_VALID(sock, send)) + { + int ret; + + if ((ret = proto_tls->ops->send(sock->user_data_tls, dataptr, size)) < 0) + { + return -1; + } + return ret; + } + else + { + return sock->ops->sendto((int) sock->user_data, dataptr, size, flags, to, tolen); + } +#else return sock->ops->sendto((int) sock->user_data, dataptr, size, flags, to, tolen); +#endif } int sal_socket(int domain, int type, int protocol) @@ -657,8 +817,17 @@ int sal_socket(int domain, int type, int protocol) proto_socket = sock->ops->socket(domain, type, protocol); if (proto_socket >= 0) { +#ifdef SAL_USING_TLS + if (SAL_SOCKOPS_PROTO_TLS_VALID(sock, socket)) + { + sock->user_data_tls = proto_tls->ops->socket(proto_socket); + if (sock->user_data_tls == RT_NULL) + { + return -1; + } + } +#endif sock->user_data = (void *) proto_socket; - return sock->socket; } @@ -682,6 +851,15 @@ int sal_closesocket(int socket) if (sock->ops->closesocket((int) sock->user_data) == 0) { +#ifdef SAL_USING_TLS + if (SAL_SOCKOPS_PROTO_TLS_VALID(sock, closesocket)) + { + if (proto_tls->ops->closesocket(sock->user_data_tls) < 0) + { + return -1; + } + } +#endif rt_free(sock); socket_table.sockets[socket] = RT_NULL; return 0; @@ -736,9 +914,9 @@ struct hostent *sal_gethostbyname(const char *name) for (i = 0; i < SAL_PROTO_FAMILIES_NUM; ++i) { - if (proto_families[i].gethostbyname) + if (proto_families[i].ops && proto_families[i].ops->gethostbyname) { - hst = proto_families[i].gethostbyname(name); + hst = proto_families[i].ops->gethostbyname(name); if (hst != RT_NULL) { return hst; @@ -756,9 +934,9 @@ int sal_gethostbyname_r(const char *name, struct hostent *ret, char *buf, for (i = 0; i < SAL_PROTO_FAMILIES_NUM; ++i) { - if (proto_families[i].gethostbyname_r) + if (proto_families[i].ops && proto_families[i].ops->gethostbyname_r) { - res = proto_families[i].gethostbyname_r(name, ret, buf, buflen, result, h_errnop); + res = proto_families[i].ops->gethostbyname_r(name, ret, buf, buflen, result, h_errnop); if (res == 0) { return res; @@ -769,20 +947,6 @@ int sal_gethostbyname_r(const char *name, struct hostent *ret, char *buf, return -1; } -void sal_freeaddrinfo(struct addrinfo *ai) -{ - int i; - - for (i = 0; i < SAL_PROTO_FAMILIES_NUM; ++i) - { - if (proto_families[i].freeaddrinfo) - { - proto_families[i].freeaddrinfo(ai); - return; - } - } -} - int sal_getaddrinfo(const char *nodename, const char *servname, const struct addrinfo *hints, @@ -792,9 +956,9 @@ int sal_getaddrinfo(const char *nodename, for (i = 0; i < SAL_PROTO_FAMILIES_NUM; ++i) { - if (proto_families[i].getaddrinfo) + if (proto_families[i].ops && proto_families[i].ops->getaddrinfo) { - ret = proto_families[i].getaddrinfo(nodename, servname, hints, res); + ret = proto_families[i].ops->getaddrinfo(nodename, servname, hints, res); if (ret == 0) { return ret; @@ -804,3 +968,17 @@ int sal_getaddrinfo(const char *nodename, return -1; } + +void sal_freeaddrinfo(struct addrinfo *ai) +{ + int i; + + for (i = 0; i < SAL_PROTO_FAMILIES_NUM; ++i) + { + if (proto_families[i].ops && proto_families[i].ops->freeaddrinfo) + { + proto_families[i].ops->freeaddrinfo(ai); + return; + } + } +} -- GitLab