/*
 * Copyright (c) 2006-2018, RT-Thread Development Team
 *
 * SPDX-License-Identifier: Apache-2.0
 *
 * Change Logs:
 * Date           Author       Notes
 * 2018-05-23     ChenYong     First version
 * 2018-11-12     ChenYong     Add TLS support
 */

#include <rtthread.h>
#include <rthw.h>
14
#include <sys/time.h>
15 16 17

#include <sal_socket.h>
#include <sal_netdb.h>
18 19 20
#ifdef SAL_USING_TLS
#include <sal_tls.h>
#endif
21
#include <sal.h>
22
#include <netdev.h>
23

24 25 26 27 28 29 30 31 32
#include <ipc/workqueue.h>

/* check system workqueue stack size */
#if RT_SYSTEM_WORKQUEUE_STACKSIZE < 1536
#error "The system workqueue stack size must more than 1536 bytes"
#endif

#define DBG_TAG                        "sal.skt"
#define DBG_LVL                        DBG_INFO
33 34
#include <rtdbg.h>

35 36
/* initial size of the socket table and the number of entries it grows by
 * whenever it runs out of free slots (both capped at SAL_SOCKETS_NUM) */
#define SOCKET_TABLE_STEP_LEN          4

37 38 39 40 41 42 43
/* the socket table used to dynamic allocate sockets */
struct sal_socket_table
{
    uint32_t max_socket;            /* current number of entries in sockets[] */
    struct sal_socket **sockets;    /* socket objects, RT_NULL for unused or freed entries */
};

44 45 46 47 48
#ifdef SAL_USING_TLS
/* The global TLS protocol options, set by sal_proto_tls_register();
 * RT_NULL (static zero-initialization) until a TLS implementation registers */
static struct sal_proto_tls *proto_tls;
#endif

49 50 51 52 53
/* The global socket table */
static struct sal_socket_table socket_table;
/* serializes socket table allocation, taken through sal_lock()/sal_unlock() */
static struct rt_mutex sal_core_lock;
/* set once sal_init() has completed successfully */
static rt_bool_t init_ok = RT_FALSE;

54 55 56 57 58
/* the socket uses the TLS or DTLS protocol */
#define IS_SOCKET_PROTO_TLS(sock)                (((sock)->protocol == PROTOCOL_TLS) || \
                                                 ((sock)->protocol == PROTOCOL_DTLS))

/* a TLS protocol is registered, it provides operation `name`, and the socket is TLS/DTLS */
#define SAL_SOCKOPS_PROTO_TLS_VALID(sock, name)  (proto_tls && (proto_tls->ops->name) && IS_SOCKET_PROTO_TLS(sock))

/* if the TLS option setter `name` is usable for this socket, call it and
 * return its result from the calling function; otherwise do nothing */
#define SAL_SOCKOPT_PROTO_TLS_EXEC(sock, name, optval, optlen)                    \
do {                                                                              \
    if (SAL_SOCKOPS_PROTO_TLS_VALID(sock, name)) {                                \
        return proto_tls->ops->name((sock)->user_data_tls, (optval), (optlen));   \
    }                                                                             \
} while (0)

/* look up the sal socket object; return -1 from the calling function if invalid */
#define SAL_SOCKET_OBJ_GET(sock, socket)                                          \
do {                                                                              \
    (sock) = sal_get_socket(socket);                                              \
    if ((sock) == RT_NULL) {                                                      \
        return -1;                                                                \
    }                                                                             \
} while (0)

/* return -1 from the calling function when the network interface is missing or down */
#define SAL_NETDEV_IS_UP(netdev)                                                  \
do {                                                                              \
    if ((netdev) == RT_NULL || !netdev_is_up(netdev)) {                           \
        return -1;                                                                \
    }                                                                             \
} while (0)

/* fetch the protocol family of the network interface into `pf`; return -1 from
 * the calling function when the device, its family, its socket operation table
 * or the requested socket operation `ops` is missing */
#define SAL_NETDEV_SOCKETOPS_VALID(netdev, pf, ops)                               \
do {                                                                              \
    if ((netdev) == RT_NULL) {                                                    \
        return -1;                                                                \
    }                                                                             \
    (pf) = (struct sal_proto_family *) (netdev)->sal_user_data;                   \
    if ((pf) == RT_NULL || (pf)->skt_ops == RT_NULL ||                            \
            (pf)->skt_ops->ops == RT_NULL) {                                      \
        return -1;                                                                \
    }                                                                             \
} while (0)

/* expression: non-zero when the network interface is up and its protocol family
 * provides netdb operation `ops`; stores the protocol family into `pf` */
#define SAL_NETDEV_NETDBOPS_VALID(netdev, pf, ops)                                \
    ((netdev) && netdev_is_up(netdev) &&                                          \
    ((pf) = (struct sal_proto_family *) (netdev)->sal_user_data) != RT_NULL &&    \
    (pf)->netdb_ops != RT_NULL && (pf)->netdb_ops->ops)

93
/**
94
 * SAL (Socket Abstraction Layer) initialize.
95
 *
96 97
 * @return result  0: initialize success
 *                -1: initialize failed        
98 99 100
 */
int sal_init(void)
{
101 102
    int cn;
    
103
    if (init_ok)
104 105 106 107 108
    {
        LOG_D("Socket Abstraction Layer is already initialized.");
        return 0;
    }

109 110 111 112 113 114 115 116 117 118
    /* init sal socket table */
    cn = SOCKET_TABLE_STEP_LEN < SAL_SOCKETS_NUM ? SOCKET_TABLE_STEP_LEN : SAL_SOCKETS_NUM;
    socket_table.max_socket = cn;
    socket_table.sockets = rt_calloc(1, cn * sizeof(struct sal_socket *));
    if (socket_table.sockets == RT_NULL)
    {
        LOG_E("No memory for socket table.\n");
        return -1;
    }
    
119 120 121 122 123 124 125 126 127 128
    /* create sal socket lock */
    rt_mutex_init(&sal_core_lock, "sal_lock", RT_IPC_FLAG_FIFO);

    LOG_I("Socket Abstraction Layer initialize success.");
    init_ok = RT_TRUE;

    return 0;
}
INIT_COMPONENT_EXPORT(sal_init);

129 130
/* check SAL network interface device internet status */
static void check_netdev_internet_up_work(struct rt_work *work, void *work_data)
131
{
132 133 134
#define SAL_INTERNET_VERSION   0x00
#define SAL_INTERNET_BUFF_LEN  12
#define SAL_INTERNET_TIMEOUT   (2 * RT_TICK_PER_SECOND)
135

136 137
#define SAL_INTERNET_HOST      "link.rt-thread.org"
#define SAL_INTERNET_PORT      8101
138

139 140
#define SAL_INTERNET_MONTH_LEN 4
#define SAL_INTERNET_DATE_LEN  16
141

142
    int index, sockfd = -1, result = 0;
143 144 145 146 147 148
    struct sockaddr_in server_addr;
    struct hostent *host;
    struct timeval timeout;
    struct netdev *netdev = (struct netdev *)work_data;
    socklen_t addr_len = sizeof(struct sockaddr_in);
    char send_data[SAL_INTERNET_BUFF_LEN], recv_data = 0;
149
    struct rt_delayed_work *delay_work = (struct rt_delayed_work *)work;
150

151 152 153 154 155 156 157 158
    const char month[][SAL_INTERNET_MONTH_LEN] = {"Jan","Feb","Mar","Apr","May","Jun","Jul","Aug","Sep","Oct","Nov","Dec"};
    char date[SAL_INTERNET_DATE_LEN];
    int moth_num = 0;

    struct sal_proto_family *pf = (struct sal_proto_family *) netdev->sal_user_data;
    const struct sal_socket_ops *skt_ops;

    if (work)
159
    {
160
        rt_free(delay_work);
161 162
    }

163 164
    /* get network interface socket operations */
    if (pf == RT_NULL || pf->skt_ops == RT_NULL)
165
    {
166 167
        result = -RT_ERROR;
        goto __exit;
168 169
    }

170 171 172 173 174 175
    host = (struct hostent *) pf->netdb_ops->gethostbyname(SAL_INTERNET_HOST);
    if (host == RT_NULL)
    {
        result = -RT_ERROR;
        goto __exit;
    }
176

H
HubretXie 已提交
177
    skt_ops = pf->skt_ops;
178 179 180 181 182
    if((sockfd = skt_ops->socket(AF_INET, SOCK_DGRAM, 0)) < 0)
    {
        result = -RT_ERROR;
        goto __exit;
    }
183

184 185 186 187
    server_addr.sin_family = AF_INET;
    server_addr.sin_port = htons(SAL_INTERNET_PORT);
    server_addr.sin_addr = *((struct in_addr *)host->h_addr);
    rt_memset(&(server_addr.sin_zero), 0, sizeof(server_addr.sin_zero));
188

189 190
    timeout.tv_sec = SAL_INTERNET_TIMEOUT;
    timeout.tv_usec = 0;
191

192 193 194
    /* set receive and send timeout */
    skt_ops->setsockopt(sockfd, SOL_SOCKET, SO_RCVTIMEO, (void *) &timeout, sizeof(timeout));
    skt_ops->setsockopt(sockfd, SOL_SOCKET, SO_SNDTIMEO, (void *) &timeout, sizeof(timeout));
195

196 197 198
    /* get build moth value*/
    rt_memset(date, 0x00, SAL_INTERNET_DATE_LEN);
    rt_snprintf(date, SAL_INTERNET_DATE_LEN, "%s", __DATE__);
199

200
    for (index = 0; index < sizeof(month) / SAL_INTERNET_MONTH_LEN; index++)
201
    {
202
        if (rt_memcmp(date, month[index], SAL_INTERNET_MONTH_LEN - 1) == 0)
203
        {
204 205
            moth_num = index + 1;
            break;
206 207 208
        }
    }

209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256
    /* not find build month */
    if (moth_num == 0 || moth_num > sizeof(month) / SAL_INTERNET_MONTH_LEN)
    {
        result = -RT_ERROR;
        goto __exit;
    }

    rt_memset(send_data, 0x00, SAL_INTERNET_BUFF_LEN);
    send_data[0] = SAL_INTERNET_VERSION;
    for (index = 0; index < netdev->hwaddr_len; index++)
    {
        send_data[index + 1] = netdev->hwaddr[index] + moth_num;
    }
    send_data[9] = RT_VERSION;
    send_data[10] = RT_SUBVERSION;
    send_data[11] = RT_REVISION;

    skt_ops->sendto(sockfd, send_data, SAL_INTERNET_BUFF_LEN, 0,
            (struct sockaddr *)&server_addr, sizeof(struct sockaddr));

    result = skt_ops->recvfrom(sockfd, &recv_data, sizeof(recv_data), 0, (struct sockaddr *)&server_addr, &addr_len);
    if (result < 0)
    {
        goto __exit;
    }

    if (recv_data == RT_FALSE)
    {
        result = -RT_ERROR;
        goto __exit;
    }

__exit:
    if (result > 0)
    {
        LOG_D("Set network interface device(%s) internet status up.", netdev->name);
        netdev->flags |= NETDEV_FLAG_INTERNET_UP;       
    }
    else
    {
        LOG_D("Set network interface device(%s) internet status down.", netdev->name);
        netdev->flags &= ~NETDEV_FLAG_INTERNET_UP;
    }

    if (sockfd >= 0)
    {
        skt_ops->closesocket(sockfd);
    }
257 258 259
}

/**
260
 * This function will check SAL network interface device internet status.
261
 *
262
 * @param netdev the network interface device to check
263
 */
264
int sal_check_netdev_internet_up(struct netdev *netdev)
265
{
266 267
    /* workqueue for network connect */
    struct rt_delayed_work *net_work = RT_NULL;
268

269
    RT_ASSERT(netdev);
270

271 272
    net_work = (struct rt_delayed_work *)rt_calloc(1, sizeof(struct rt_delayed_work));
    if (net_work == RT_NULL)
273
    {
274 275
        LOG_W("No memory for network interface device(%s) delay work.", netdev->name);
        return -1;
276 277
    }

278 279 280 281
    rt_delayed_work_init(net_work, check_netdev_internet_up_work, (void *)netdev);
    rt_work_submit(&(net_work->work), RT_TICK_PER_SECOND);
    
    return 0;
282 283 284
}

#ifdef SAL_USING_TLS
/**
 * Register the global TLS protocol implementation used by TLS/DTLS sockets.
 *
 * Only the pointer is stored, so the object must stay valid while SAL uses
 * it. A later registration replaces the previous one.
 *
 * @param pt TLS protocol object
 *
 * @return 0: TLS protocol object register success
 */
int sal_proto_tls_register(const struct sal_proto_tls *pt)
{
    RT_ASSERT(pt);

    proto_tls = (struct sal_proto_tls *) pt;
    return 0;
}
#endif /* SAL_USING_TLS */
300 301 302

/**
 * This function will get sal socket object by sal socket descriptor.
 *
 * @param socket sal socket descriptor (table index + SAL_SOCKET_OFFSET)
 *
 * @return sal socket object of the descriptor, or RT_NULL when the descriptor
 *         is out of range, its table slot is empty, or the object is invalid
 */
struct sal_socket *sal_get_socket(int socket)
{
    struct sal_socket_table *st = &socket_table;
    int idx;

    /* descriptors handed out by socket_new() carry SAL_SOCKET_OFFSET; convert
     * back to a table index before the range check */
    idx = socket - SAL_SOCKET_OFFSET;
    if (idx < 0 || idx >= (int) st->max_socket)
    {
        return RT_NULL;
    }

    /* the slot is RT_NULL once the socket has been closed and freed */
    if (st->sockets[idx] == RT_NULL || st->sockets[idx]->magic != SAL_SOCKET_MAGIC)
    {
        return RT_NULL;
    }

    return st->sockets[idx];
}

/**
 * Take the SAL core lock, waiting without timeout.
 *
 * A failure to take the lock trips RT_ASSERT.
 *
 * @note please don't invoke it on ISR.
 */
static void sal_lock(void)
{
    rt_err_t err = rt_mutex_take(&sal_core_lock, RT_WAITING_FOREVER);

    RT_ASSERT(err == RT_EOK);
    (void) err; /* unused when RT_ASSERT compiles to nothing */
}

/**
 * This function will unlock sal socket.
 *
 * Releases the SAL core lock previously taken with sal_lock().
 *
 * @note please don't invoke it on ISR.
 */
static void sal_unlock(void)
{
    rt_mutex_release(&sal_core_lock);
}

/**
354
 * This function will initialize sal socket object and set socket options
355 356 357 358
 *
 * @param family    protocol family
 * @param type      socket type
 * @param protocol  transfer Protocol
359
 * @param res       sal socket object address
360 361 362 363
 *
 * @return  0 : socket initialize success
 *         -1 : input the wrong family
 *         -2 : input the wrong socket type
364
 *         -3 : get network interface failed
365 366 367
 */
static int socket_init(int family, int type, int protocol, struct sal_socket **res)
{
368

369
    struct sal_socket *sock;
370
    struct sal_proto_family *pf;
371 372
    struct netdev *netdv_def = netdev_default;
    struct netdev *netdev = RT_NULL;
373
    rt_bool_t flag = RT_FALSE;
374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389

    if (family < 0 || family > AF_MAX)
    {
        return -1;
    }

    if (type < 0 || type > SOCK_MAX)
    {
        return -2;
    }

    sock = *res;
    sock->domain = family;
    sock->type = type;
    sock->protocol = protocol;

390
    if (netdv_def && netdev_is_up(netdv_def))
391
    {
392 393 394
        /* check default network interface device protocol family */
        pf = (struct sal_proto_family *) netdv_def->sal_user_data;
        if (pf != RT_NULL && pf->skt_ops && (pf->family == family || pf->sec_family == family))
395
        {
396
            sock->netdev = netdv_def;
397
            flag = RT_TRUE;
398 399 400
        }
    }
    
401
    if (flag == RT_FALSE)
402
    {
403 404 405 406 407 408 409 410 411
        /* get network interface device by protocol family */
        netdev = netdev_get_by_family(family);
        if (netdev == RT_NULL)
        {
            LOG_E("not find network interface device by protocol family(%d).", family);
            return -3;
        }

        sock->netdev = netdev;
412 413 414 415 416 417 418 419 420 421 422 423
    }

    return 0;
}

static int socket_alloc(struct sal_socket_table *st, int f_socket)
{
    int idx;

    /* find an empty socket entry */
    for (idx = f_socket; idx < (int) st->max_socket; idx++)
    {
424 425 426
        if (st->sockets[idx] == RT_NULL || 
                st->sockets[idx]->netdev == RT_NULL)
        {
427
            break;
428
        }
429 430 431 432 433 434 435 436
    }

    /* allocate a larger sockte container */
    if (idx == (int) st->max_socket &&  st->max_socket < SAL_SOCKETS_NUM)
    {
        int cnt, index;
        struct sal_socket **sockets;

437 438
        /* increase the number of socket with 4 step length */
        cnt = st->max_socket + SOCKET_TABLE_STEP_LEN;
439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457
        cnt = cnt > SAL_SOCKETS_NUM ? SAL_SOCKETS_NUM : cnt;

        sockets = rt_realloc(st->sockets, cnt * sizeof(struct sal_socket *));
        if (sockets == RT_NULL)
            goto __result; /* return st->max_socket */

        /* clean the new allocated fds */
        for (index = st->max_socket; index < cnt; index++)
        {
            sockets[index] = RT_NULL;
        }

        st->sockets = sockets;
        st->max_socket = cnt;
    }

    /* allocate  'struct sal_socket' */
    if (idx < (int) st->max_socket && st->sockets[idx] == RT_NULL)
    {
458
        st->sockets[idx] = rt_calloc(1, sizeof(struct sal_socket));
459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489
        if (st->sockets[idx] == RT_NULL)
        {
            idx = st->max_socket;
        }
    }

__result:
    return idx;
}

static int socket_new(void)
{
    struct sal_socket *sock;
    struct sal_socket_table *st = &socket_table;
    int idx;

    sal_lock();

    /* find an empty sal socket entry */
    idx = socket_alloc(st, 0);

    /* can't find an empty sal socket entry */
    if (idx == (int) st->max_socket)
    {
        idx = -(1 + SAL_SOCKET_OFFSET);
        goto __result;
    }

    sock = st->sockets[idx];
    sock->socket = idx + SAL_SOCKET_OFFSET;
    sock->magic = SAL_SOCKET_MAGIC;
490
    sock->netdev = RT_NULL;
491 492 493 494
    sock->user_data = RT_NULL;
#ifdef SAL_USING_TLS
    sock->user_data_tls = RT_NULL;
#endif
495 496 497 498 499 500 501 502 503 504

__result:
    sal_unlock();
    return idx + SAL_SOCKET_OFFSET;
}

/**
 * Accept a connection on a listening sal socket.
 *
 * A new sal socket object is allocated for the accepted connection. It takes
 * the domain, type and protocol of the listening socket and is bound to the
 * listening socket's network interface device, because the accepted
 * descriptor comes from that device's socket operations.
 *
 * @return new sal socket descriptor, or -1 on failure
 */
int sal_accept(int socket, struct sockaddr *addr, socklen_t *addrlen)
{
    int new_socket;
    int new_sal_socket;
    int retval;
    struct sal_socket *sock;
    struct sal_socket *new_sock;
    struct sal_proto_family *pf;

    /* get the socket object by socket descriptor */
    SAL_SOCKET_OBJ_GET(sock, socket);

    /* check the network interface socket operations */
    SAL_NETDEV_SOCKETOPS_VALID(sock->netdev, pf, accept);

    new_socket = pf->skt_ops->accept((int) sock->user_data, addr, addrlen);
    if (new_socket < 0)
    {
        return -1;
    }

    /* allocate a new socket structure and registered socket options */
    new_sal_socket = socket_new();
    if (new_sal_socket < 0)
    {
        pf->skt_ops->closesocket(new_socket);
        return -1;
    }

    new_sock = sal_get_socket(new_sal_socket);
    if (new_sock == RT_NULL)
    {
        pf->skt_ops->closesocket(new_socket);
        return -1;
    }

    retval = socket_init(sock->domain, sock->type, sock->protocol, &new_sock);
    if (retval < 0)
    {
        pf->skt_ops->closesocket(new_socket);
        rt_memset(new_sock, 0x00, sizeof(struct sal_socket));
        LOG_E("New socket registered failed, return error %d.", retval);
        return -1;
    }

    /* the accepted descriptor belongs to the listening socket's network
     * interface device, which may differ from the one socket_init() chose */
    new_sock->netdev = sock->netdev;

    /* socket structure user_data used to store the acquired new socket */
    new_sock->user_data = (void *) new_socket;

    return new_sal_socket;
}

547 548 549 550 551
/* Copy the IPv4 address of `name` (read through the sockaddr_in layout) into an ip_addr_t. */
static void sal_sockaddr_to_ipaddr(const struct sockaddr *name, ip_addr_t *local_ipaddr)
{
    const struct sockaddr_in *sin = (const struct sockaddr_in *) name;

#if NETDEV_IPV4 && NETDEV_IPV6
    local_ipaddr->type = IPADDR_TYPE_V4;
    local_ipaddr->u_addr.ip4.addr = sin->sin_addr.s_addr;
#elif NETDEV_IPV4
    local_ipaddr->addr = sin->sin_addr.s_addr;
#elif NETDEV_IPV6
#error "not only support IPV6"
#endif /* NETDEV_IPV4 && NETDEV_IPV6 */
}

561 562 563
/**
 * Bind a sal socket to a local address.
 *
 * When `name` holds a specific (non-any) IP address, the socket moves to the
 * network interface device that owns it. If that device's protocol family
 * differs from the current one, a new protocol socket is created on the new
 * device and the old protocol socket is closed.
 *
 * @return protocol stack bind() result, or -1 on failure
 */
int sal_bind(int socket, const struct sockaddr *name, socklen_t namelen)
{
    struct sal_socket *sock;
    struct sal_proto_family *pf;
    ip_addr_t input_ipaddr;

    RT_ASSERT(name);

    /* get the socket object by socket descriptor */
    SAL_SOCKET_OBJ_GET(sock, socket);

    /* bind network interface by ip address */
    sal_sockaddr_to_ipaddr(name, &input_ipaddr);

    /* check input ipaddr is default netdev ipaddr */
    if (!ip_addr_isany_val(input_ipaddr))
    {
        struct sal_proto_family *input_pf = RT_NULL, *local_pf = RT_NULL;
        struct netdev *new_netdev = RT_NULL;

        new_netdev = netdev_get_by_ipaddr(&input_ipaddr);
        if (new_netdev == RT_NULL)
        {
            return -1;
        }

        /* get input and local ip address proto_family */
        SAL_NETDEV_SOCKETOPS_VALID(sock->netdev, local_pf, bind);
        SAL_NETDEV_SOCKETOPS_VALID(new_netdev, input_pf, bind);

        /* check the network interface protocol family type */
        if (input_pf->family != local_pf->family)
        {
            int new_socket = -1;

            /* create the replacement first so a failure leaves the original
             * socket untouched */
            new_socket = input_pf->skt_ops->socket(input_pf->family, sock->type, sock->protocol);
            if (new_socket < 0)
            {
                return -1;
            }

            /* close the protocol stack descriptor, not the sal descriptor */
            local_pf->skt_ops->closesocket((int) sock->user_data);

            sock->netdev = new_netdev;
            sock->user_data = (void *) new_socket;
        }
    }

    /* check and get protocol families by the network interface device */
    SAL_NETDEV_SOCKETOPS_VALID(sock->netdev, pf, bind);
    return pf->skt_ops->bind((int) sock->user_data, name, namelen);
}

int sal_shutdown(int socket, int how)
{
    struct sal_socket *sock;
617 618
    struct sal_proto_family *pf;
    int error = 0;
619

620 621
    /* get the socket object by socket descriptor */
    SAL_SOCKET_OBJ_GET(sock, socket);
622

623 624 625
    /* shutdown operation not nead to check network interface status */
    /* check the network interface socket opreation */
    SAL_NETDEV_SOCKETOPS_VALID(sock->netdev, pf, shutdown);
626

627
    if (pf->skt_ops->shutdown((int) sock->user_data, how) == 0)
628
    {
629 630 631 632 633 634 635 636 637
#ifdef SAL_USING_TLS
        if (SAL_SOCKOPS_PROTO_TLS_VALID(sock, closesocket))
        {
            if (proto_tls->ops->closesocket(sock->user_data_tls) < 0)
            {
                return -1;
            }
        }
#endif
638 639 640 641 642
        error = 0;
    }
    else
    {
        error = -1;
643 644
    }

645 646 647 648 649
    /* free socket */
    rt_free(sock);
    socket_table.sockets[socket] = RT_NULL;

    return error;
650 651 652 653 654
}

/**
 * Get the peer address of a sal socket from its network interface device.
 *
 * @return protocol stack getpeername() result, or -1 when the descriptor or
 *         the device socket operation is invalid
 */
int sal_getpeername(int socket, struct sockaddr *name, socklen_t *namelen)
{
    struct sal_socket *sk;
    struct sal_proto_family *proto;

    SAL_SOCKET_OBJ_GET(sk, socket);
    SAL_NETDEV_SOCKETOPS_VALID(sk->netdev, proto, getpeername);

    return proto->skt_ops->getpeername((int) sk->user_data, name, namelen);
}

/**
 * Get the local address of a sal socket from its network interface device.
 *
 * @return protocol stack getsockname() result, or -1 when the descriptor or
 *         the device socket operation is invalid
 */
int sal_getsockname(int socket, struct sockaddr *name, socklen_t *namelen)
{
    struct sal_socket *sk;
    struct sal_proto_family *proto;

    SAL_SOCKET_OBJ_GET(sk, socket);
    SAL_NETDEV_SOCKETOPS_VALID(sk->netdev, proto, getsockname);

    return proto->skt_ops->getsockname((int) sk->user_data, name, namelen);
}

/**
 * Read a socket option through the network interface device socket operations.
 *
 * @return protocol stack getsockopt() result, or -1 when the descriptor or
 *         the device socket operation is invalid
 */
int sal_getsockopt(int socket, int level, int optname, void *optval, socklen_t *optlen)
{
    struct sal_socket *sk;
    struct sal_proto_family *proto;

    SAL_SOCKET_OBJ_GET(sk, socket);
    SAL_NETDEV_SOCKETOPS_VALID(sk->netdev, proto, getsockopt);

    return proto->skt_ops->getsockopt((int) sk->user_data, level, optname, optval, optlen);
}

/**
 * Set a socket option.
 *
 * With SAL_USING_TLS, SOL_TLS options are forwarded to the registered TLS
 * protocol. The setter's result is returned when it is usable for this socket;
 * otherwise 0 is returned. Unknown SOL_TLS options return -1. All other levels
 * go to the network interface device setsockopt operation.
 */
int sal_setsockopt(int socket, int level, int optname, const void *optval, socklen_t optlen)
{
    struct sal_socket *sk;
    struct sal_proto_family *proto;

    SAL_SOCKET_OBJ_GET(sk, socket);
    SAL_NETDEV_SOCKETOPS_VALID(sk->netdev, proto, setsockopt);

#ifdef SAL_USING_TLS
    if (level == SOL_TLS)
    {
        /* each EXEC returns the setter result directly when it is usable */
        switch (optname)
        {
        case TLS_CRET_LIST:
            SAL_SOCKOPT_PROTO_TLS_EXEC(sk, set_cret_list, optval, optlen);
            break;

        case TLS_CIPHERSUITE_LIST:
            SAL_SOCKOPT_PROTO_TLS_EXEC(sk, set_ciphersurite, optval, optlen);
            break;

        case TLS_PEER_VERIFY:
            SAL_SOCKOPT_PROTO_TLS_EXEC(sk, set_peer_verify, optval, optlen);
            break;

        case TLS_DTLS_ROLE:
            SAL_SOCKOPT_PROTO_TLS_EXEC(sk, set_dtls_role, optval, optlen);
            break;

        default:
            return -1;
        }

        return 0;
    }
#endif /* SAL_USING_TLS */

    return proto->skt_ops->setsockopt((int) sk->user_data, level, optname, optval, optlen);
}

/**
 * Connect a sal socket; for TLS/DTLS sockets the TLS connect follows a
 * successful protocol stack connect.
 *
 * @return protocol stack connect() result, or -1 when the descriptor is
 *         invalid, the device is down, the operation is missing, or the TLS
 *         connect fails
 */
int sal_connect(int socket, const struct sockaddr *name, socklen_t namelen)
{
    struct sal_socket *sk;
    struct sal_proto_family *proto;
    int rc;

    SAL_SOCKET_OBJ_GET(sk, socket);
    SAL_NETDEV_IS_UP(sk->netdev);
    SAL_NETDEV_SOCKETOPS_VALID(sk->netdev, proto, connect);

    rc = proto->skt_ops->connect((int) sk->user_data, name, namelen);

#ifdef SAL_USING_TLS
    if (rc >= 0 && SAL_SOCKOPS_PROTO_TLS_VALID(sk, connect) &&
            proto_tls->ops->connect(sk->user_data_tls) < 0)
    {
        return -1;
    }
#endif

    return rc;
}

int sal_listen(int socket, int backlog)
{
    struct sal_socket *sock;
774
    struct sal_proto_family *pf;
775

776 777
    /* get the socket object by socket descriptor */
    SAL_SOCKET_OBJ_GET(sock, socket);
778

779 780
    /* check the network interface socket opreation */
    SAL_NETDEV_SOCKETOPS_VALID(sock->netdev, pf, listen);
781

782
    return pf->skt_ops->listen((int) sock->user_data, backlog);
783 784 785 786 787 788
}

/**
 * Receive data from a sal socket.
 *
 * TLS/DTLS sockets with a usable TLS recv operation read through the TLS layer;
 * `flags`, `from` and `fromlen` are not used on that path. Other sockets use
 * the network interface device recvfrom operation.
 *
 * @return number of bytes received, or -1 on failure
 */
int sal_recvfrom(int socket, void *mem, size_t len, int flags,
             struct sockaddr *from, socklen_t *fromlen)
{
    struct sal_socket *sk;
    struct sal_proto_family *proto;

    SAL_SOCKET_OBJ_GET(sk, socket);
    SAL_NETDEV_IS_UP(sk->netdev);
    SAL_NETDEV_SOCKETOPS_VALID(sk->netdev, proto, recvfrom);

#ifdef SAL_USING_TLS
    if (SAL_SOCKOPS_PROTO_TLS_VALID(sk, recv))
    {
        int got = proto_tls->ops->recv(sk->user_data_tls, mem, len);

        return (got < 0) ? -1 : got;
    }
#endif

    return proto->skt_ops->recvfrom((int) sk->user_data, mem, len, flags, from, fromlen);
}

/**
 * Send data on a sal socket.
 *
 * TLS/DTLS sockets with a usable TLS send operation write through the TLS
 * layer; `flags`, `to` and `tolen` are not used on that path. Other sockets
 * use the network interface device sendto operation.
 *
 * @return number of bytes sent, or -1 on failure
 */
int sal_sendto(int socket, const void *dataptr, size_t size, int flags,
           const struct sockaddr *to, socklen_t tolen)
{
    struct sal_socket *sk;
    struct sal_proto_family *proto;

    SAL_SOCKET_OBJ_GET(sk, socket);
    SAL_NETDEV_IS_UP(sk->netdev);
    SAL_NETDEV_SOCKETOPS_VALID(sk->netdev, proto, sendto);

#ifdef SAL_USING_TLS
    if (SAL_SOCKOPS_PROTO_TLS_VALID(sk, send))
    {
        int sent = proto_tls->ops->send(sk->user_data_tls, dataptr, size);

        return (sent < 0) ? -1 : sent;
    }
#endif

    return proto->skt_ops->sendto((int) sk->user_data, dataptr, size, flags, to, tolen);
}

int sal_socket(int domain, int type, int protocol)
{
    int retval;
    int socket, proto_socket;
    struct sal_socket *sock;
858
    struct sal_proto_family *pf;
859 860 861 862 863 864 865

    /* allocate a new socket and registered socket options */
    socket = socket_new();
    if (socket < 0)
    {
        return -1;
    }
866 867

    /* get sal socket object by socket descriptor */
868
    sock = sal_get_socket(socket);
869 870 871 872
    if (sock == RT_NULL)
    {
        return -1;
    }
873

874
    /* Initialize sal socket object */
875 876 877 878 879 880 881
    retval = socket_init(domain, type, protocol, &sock);
    if (retval < 0)
    {
        LOG_E("SAL socket protocol family input failed, return error %d.", retval);
        return -1;
    }

882 883
    /* valid the network interface socket opreation */
    SAL_NETDEV_SOCKETOPS_VALID(sock->netdev, pf, socket);
884

885
    proto_socket = pf->skt_ops->socket(domain, type, protocol);
886 887
    if (proto_socket >= 0)
    {
888 889 890
#ifdef SAL_USING_TLS
        if (SAL_SOCKOPS_PROTO_TLS_VALID(sock, socket))
        {
891
            sock->user_data_tls = proto_tls->ops->socket(socket);
892 893 894 895 896 897
            if (sock->user_data_tls == RT_NULL)
            {
                return -1;
            }
        }
#endif
898 899 900 901 902 903 904 905 906 907
        sock->user_data = (void *) proto_socket;
        return sock->socket;
    }

    return -1;
}

int sal_closesocket(int socket)
{
    struct sal_socket *sock;
908 909
    struct sal_proto_family *pf;
    int error = 0;
910

911 912
    /* get the socket object by socket descriptor */
    SAL_SOCKET_OBJ_GET(sock, socket);
913

914 915 916
    /* clsoesocket operation not nead to vaild network interface status */
    /* valid the network interface socket opreation */
    SAL_NETDEV_SOCKETOPS_VALID(sock->netdev, pf, socket);
917

918
    if (pf->skt_ops->closesocket((int) sock->user_data) == 0)
919
    {
920 921 922 923 924 925 926 927 928
#ifdef SAL_USING_TLS
        if (SAL_SOCKOPS_PROTO_TLS_VALID(sock, closesocket))
        {
            if (proto_tls->ops->closesocket(sock->user_data_tls) < 0)
            {
                return -1;
            }
        }
#endif
929 930 931 932 933
        error = 0;
    }
    else
    {
        error = -1;
934 935
    }

936 937 938 939 940
    /* free socket */
    rt_free(sock);        
    socket_table.sockets[socket] = RT_NULL;

    return error;
941 942 943 944 945
}

int sal_ioctlsocket(int socket, long cmd, void *arg)
{
    struct sal_socket *sock;
946
    struct sal_proto_family *pf;
947

948 949
    /* get the socket object by socket descriptor */
    SAL_SOCKET_OBJ_GET(sock, socket);
950

951 952
    /* check the network interface socket opreation */
    SAL_NETDEV_SOCKETOPS_VALID(sock->netdev, pf, ioctlsocket);
953

954
    return pf->skt_ops->ioctlsocket((int) sock->user_data, cmd, arg);
955 956
}

957
#ifdef SAL_USING_POSIX
/**
 * Poll a sal socket; the sal socket descriptor is stored in file->data.
 *
 * @return network interface device poll() result, or -1 when the descriptor
 *         is invalid, the device is down, or the poll operation is missing
 */
int sal_poll(struct dfs_fd *file, struct rt_pollreq *req)
{
    struct sal_socket *sk;
    struct sal_proto_family *proto;
    int fd = (int) file->data;

    SAL_SOCKET_OBJ_GET(sk, fd);
    SAL_NETDEV_IS_UP(sk->netdev);
    SAL_NETDEV_SOCKETOPS_VALID(sk->netdev, proto, poll);

    return proto->skt_ops->poll(file, req);
}
#endif /* SAL_USING_POSIX */
975 976 977

struct hostent *sal_gethostbyname(const char *name)
{
978 979
    struct netdev *netdev = netdev_default;
    struct sal_proto_family *pf;
980

981
    if (SAL_NETDEV_NETDBOPS_VALID(netdev, pf, gethostbyname))
982
    {
983 984 985 986
        return pf->netdb_ops->gethostbyname(name);
    }
    else
    {
987 988
        /* get the first network interface device with up status */
        netdev = netdev_get_first_by_flags(NETDEV_FLAG_UP);
989
        if (SAL_NETDEV_NETDBOPS_VALID(netdev, pf, gethostbyname))
990
        {
991
            return pf->netdb_ops->gethostbyname(name);
992 993 994 995 996 997 998 999 1000
        }
    }

    return RT_NULL;
}

/**
 * Reentrant host name resolution through the default network interface
 * device, falling back to the first device that is up.
 *
 * @return device gethostbyname_r() result, or -1 when no device provides it
 */
int sal_gethostbyname_r(const char *name, struct hostent *ret, char *buf,
                size_t buflen, struct hostent **result, int *h_errnop)
{
    struct netdev *dev = netdev_default;
    struct sal_proto_family *proto;

    if (!SAL_NETDEV_NETDBOPS_VALID(dev, proto, gethostbyname_r))
    {
        dev = netdev_get_first_by_flags(NETDEV_FLAG_UP);
        if (!SAL_NETDEV_NETDBOPS_VALID(dev, proto, gethostbyname_r))
        {
            return -1;
        }
    }

    return proto->netdb_ops->gethostbyname_r(name, ret, buf, buflen, result, h_errnop);
}

/**
 * Address lookup through the default network interface device, falling back
 * to the first device that is up.
 *
 * @return device getaddrinfo() result, or -1 when no device provides it
 */
int sal_getaddrinfo(const char *nodename,
       const char *servname,
       const struct addrinfo *hints,
       struct addrinfo **res)
{
    struct netdev *dev = netdev_default;
    struct sal_proto_family *proto;

    if (!SAL_NETDEV_NETDBOPS_VALID(dev, proto, getaddrinfo))
    {
        dev = netdev_get_first_by_flags(NETDEV_FLAG_UP);
        if (!SAL_NETDEV_NETDBOPS_VALID(dev, proto, getaddrinfo))
        {
            return -1;
        }
    }

    return proto->netdb_ops->getaddrinfo(nodename, servname, hints, res);
}
1045 1046 1047

void sal_freeaddrinfo(struct addrinfo *ai)
{
1048 1049
    struct netdev *netdev = netdev_default;
    struct sal_proto_family *pf;
1050

1051
    if (SAL_NETDEV_NETDBOPS_VALID(netdev, pf, freeaddrinfo))
1052
    {
1053 1054 1055 1056
        pf->netdb_ops->freeaddrinfo(ai);
    }
    else
    {
1057 1058
        /* get the first network interface device with up status */
        netdev = netdev_get_first_by_flags(NETDEV_FLAG_UP);
1059
        if (SAL_NETDEV_NETDBOPS_VALID(netdev, pf, freeaddrinfo))
1060
        {
1061
            pf->netdb_ops->freeaddrinfo(ai);
1062 1063 1064
        }
    }
}
1065