sal_socket.c 29.8 KB
Newer Older
1
/*
2
 * Copyright (c) 2006-2018, RT-Thread Development Team
3
 *
4
 * SPDX-License-Identifier: Apache-2.0
5 6 7 8
 *
 * Change Logs:
 * Date           Author       Notes
 * 2018-05-23     ChenYong     First version
9
 * 2018-11-12     ChenYong     Add TLS support
10 11 12 13
 */

#include <rtthread.h>
#include <rthw.h>
14
#include <sys/time.h>
15 16 17

#include <sal_socket.h>
#include <sal_netdb.h>
18 19 20
#ifdef SAL_USING_TLS
#include <sal_tls.h>
#endif
21
#include <sal.h>
22
#include <netdev.h>
23

24 25 26 27 28 29 30 31 32
#include <ipc/workqueue.h>

/* check system workqueue stack size */
#if RT_SYSTEM_WORKQUEUE_STACKSIZE < 1536
#error "The system workqueue stack size must more than 1536 bytes"
#endif

#define DBG_TAG                        "sal.skt"
#define DBG_LVL                        DBG_INFO
33 34
#include <rtdbg.h>

35 36
#define SOCKET_TABLE_STEP_LEN          4

37 38 39 40 41 42 43
/* The socket table used to dynamically allocate sockets; it is grown on demand. */
struct sal_socket_table
{
    uint32_t max_socket;            /* number of slots currently available in 'sockets' */
    struct sal_socket **sockets;    /* slot array; a slot holds RT_NULL until an object is allocated for it */
};

44 45 46 47 48
#ifdef SAL_USING_TLS
/* The global TLS protocol options, installed by sal_proto_tls_register() */
static struct sal_proto_tls *proto_tls;
#endif

49 50 51 52 53
/* The global socket table */
static struct sal_socket_table socket_table;
/* Mutex taken by sal_lock() and released by sal_unlock() */
static struct rt_mutex sal_core_lock;
/* Set to RT_TRUE at the end of a successful sal_init(); later calls return early */
static rt_bool_t init_ok = RT_FALSE;

54 55 56 57 58
/* true when the socket was created with the TLS or DTLS protocol number */
#define IS_SOCKET_PROTO_TLS(sock)                (((sock)->protocol == PROTOCOL_TLS) || \
                                                 ((sock)->protocol == PROTOCOL_DTLS))
/* true when a TLS protocol is registered, implements 'name' and the socket uses TLS/DTLS */
#define SAL_SOCKOPS_PROTO_TLS_VALID(sock, name)  (proto_tls && (proto_tls->ops->name) && IS_SOCKET_PROTO_TLS(sock))

/* when the TLS operation 'name' is usable, call it and return its result from the enclosing function */
#define SAL_SOCKOPT_PROTO_TLS_EXEC(sock, name, optval, optlen)                    \
do {                                                                              \
    if (SAL_SOCKOPS_PROTO_TLS_VALID(sock, name)){                                 \
        return proto_tls->ops->name((sock)->user_data_tls, (optval), (optlen));   \
    }                                                                             \
}while(0)

/* look up the socket object of a descriptor; the enclosing function returns -1 when there is none */
#define SAL_SOCKET_OBJ_GET(sock, socket)                                          \
do {                                                                              \
    (sock) = sal_get_socket(socket);                                              \
    if ((sock) == RT_NULL) {                                                      \
        return -1;                                                                \
    }                                                                             \
}while(0)

/* the enclosing function returns -1 unless the netdev exists, is up and has its link up */
#define SAL_NETDEV_IS_COMMONICABLE(netdev)                                        \
do {                                                                              \
    if ((netdev) == RT_NULL ||                                                    \
            !netdev_is_up(netdev) || !netdev_is_link_up(netdev)){                 \
        return -1;                                                                \
    }                                                                             \
}while(0)

/*
 * load the protocol family of the netdev into 'pf'; the enclosing function
 * returns -1 when the netdev, its protocol family, the socket operation table
 * or the requested operation 'ops' is missing
 */
#define SAL_NETDEV_SOCKETOPS_VALID(netdev, pf, ops)                               \
do {                                                                              \
    if ((netdev) == RT_NULL) {                                                    \
        return -1;                                                                \
    }                                                                             \
    (pf) = (struct sal_proto_family *) (netdev)->sal_user_data;                   \
    if ((pf) == RT_NULL || (pf)->skt_ops == RT_NULL ||                            \
            (pf)->skt_ops->ops == RT_NULL){                                       \
        return -1;                                                                \
    }                                                                             \
}while(0)

/*
 * expression form: non-zero when the netdev is usable (exists, up, link up)
 * and its protocol family provides the netdb operation 'ops'; 'pf' is loaded
 * as a side effect
 */
#define SAL_NETDEV_NETDBOPS_VALID(netdev, pf, ops)                                \
((netdev) && netdev_is_up(netdev) && netdev_is_link_up(netdev) &&                 \
    ((pf) = (struct sal_proto_family *) (netdev)->sal_user_data) != RT_NULL &&    \
        (pf)->netdb_ops != RT_NULL && (pf)->netdb_ops->ops)
92

93
/**
94
 * SAL (Socket Abstraction Layer) initialize.
95
 *
96 97
 * @return result  0: initialize success
 *                -1: initialize failed        
98 99 100
 */
int sal_init(void)
{
101 102
    int cn;
    
103
    if (init_ok)
104 105 106 107 108
    {
        LOG_D("Socket Abstraction Layer is already initialized.");
        return 0;
    }

109 110 111 112 113 114 115 116 117 118
    /* init sal socket table */
    cn = SOCKET_TABLE_STEP_LEN < SAL_SOCKETS_NUM ? SOCKET_TABLE_STEP_LEN : SAL_SOCKETS_NUM;
    socket_table.max_socket = cn;
    socket_table.sockets = rt_calloc(1, cn * sizeof(struct sal_socket *));
    if (socket_table.sockets == RT_NULL)
    {
        LOG_E("No memory for socket table.\n");
        return -1;
    }
    
119 120 121 122 123 124 125 126 127 128
    /* create sal socket lock */
    rt_mutex_init(&sal_core_lock, "sal_lock", RT_IPC_FLAG_FIFO);

    LOG_I("Socket Abstraction Layer initialize success.");
    init_ok = RT_TRUE;

    return 0;
}
INIT_COMPONENT_EXPORT(sal_init);

129 130
/* check SAL network interface device internet status */
static void check_netdev_internet_up_work(struct rt_work *work, void *work_data)
131
{
132 133 134
#define SAL_INTERNET_VERSION   0x00
#define SAL_INTERNET_BUFF_LEN  12
#define SAL_INTERNET_TIMEOUT   (2 * RT_TICK_PER_SECOND)
135

136 137
#define SAL_INTERNET_HOST      "link.rt-thread.org"
#define SAL_INTERNET_PORT      8101
138

139 140
#define SAL_INTERNET_MONTH_LEN 4
#define SAL_INTERNET_DATE_LEN  16
141

142
    int index, sockfd = -1, result = 0;
143 144 145 146 147 148
    struct sockaddr_in server_addr;
    struct hostent *host;
    struct timeval timeout;
    struct netdev *netdev = (struct netdev *)work_data;
    socklen_t addr_len = sizeof(struct sockaddr_in);
    char send_data[SAL_INTERNET_BUFF_LEN], recv_data = 0;
149
    struct rt_delayed_work *delay_work = (struct rt_delayed_work *)work;
150

151 152 153 154 155 156 157 158
    const char month[][SAL_INTERNET_MONTH_LEN] = {"Jan","Feb","Mar","Apr","May","Jun","Jul","Aug","Sep","Oct","Nov","Dec"};
    char date[SAL_INTERNET_DATE_LEN];
    int moth_num = 0;

    struct sal_proto_family *pf = (struct sal_proto_family *) netdev->sal_user_data;
    const struct sal_socket_ops *skt_ops;

    if (work)
159
    {
160
        rt_free(delay_work);
161 162
    }

163 164
    /* get network interface socket operations */
    if (pf == RT_NULL || pf->skt_ops == RT_NULL)
165
    {
166 167
        result = -RT_ERROR;
        goto __exit;
168 169
    }

170 171 172 173 174 175
    host = (struct hostent *) pf->netdb_ops->gethostbyname(SAL_INTERNET_HOST);
    if (host == RT_NULL)
    {
        result = -RT_ERROR;
        goto __exit;
    }
176

H
HubretXie 已提交
177
    skt_ops = pf->skt_ops;
178 179 180 181 182
    if((sockfd = skt_ops->socket(AF_INET, SOCK_DGRAM, 0)) < 0)
    {
        result = -RT_ERROR;
        goto __exit;
    }
183

184 185 186 187
    server_addr.sin_family = AF_INET;
    server_addr.sin_port = htons(SAL_INTERNET_PORT);
    server_addr.sin_addr = *((struct in_addr *)host->h_addr);
    rt_memset(&(server_addr.sin_zero), 0, sizeof(server_addr.sin_zero));
188

189 190
    timeout.tv_sec = SAL_INTERNET_TIMEOUT;
    timeout.tv_usec = 0;
191

192 193 194
    /* set receive and send timeout */
    skt_ops->setsockopt(sockfd, SOL_SOCKET, SO_RCVTIMEO, (void *) &timeout, sizeof(timeout));
    skt_ops->setsockopt(sockfd, SOL_SOCKET, SO_SNDTIMEO, (void *) &timeout, sizeof(timeout));
195

196 197 198
    /* get build moth value*/
    rt_memset(date, 0x00, SAL_INTERNET_DATE_LEN);
    rt_snprintf(date, SAL_INTERNET_DATE_LEN, "%s", __DATE__);
199

200
    for (index = 0; index < sizeof(month) / SAL_INTERNET_MONTH_LEN; index++)
201
    {
202
        if (rt_memcmp(date, month[index], SAL_INTERNET_MONTH_LEN - 1) == 0)
203
        {
204 205
            moth_num = index + 1;
            break;
206 207 208
        }
    }

209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256
    /* not find build month */
    if (moth_num == 0 || moth_num > sizeof(month) / SAL_INTERNET_MONTH_LEN)
    {
        result = -RT_ERROR;
        goto __exit;
    }

    rt_memset(send_data, 0x00, SAL_INTERNET_BUFF_LEN);
    send_data[0] = SAL_INTERNET_VERSION;
    for (index = 0; index < netdev->hwaddr_len; index++)
    {
        send_data[index + 1] = netdev->hwaddr[index] + moth_num;
    }
    send_data[9] = RT_VERSION;
    send_data[10] = RT_SUBVERSION;
    send_data[11] = RT_REVISION;

    skt_ops->sendto(sockfd, send_data, SAL_INTERNET_BUFF_LEN, 0,
            (struct sockaddr *)&server_addr, sizeof(struct sockaddr));

    result = skt_ops->recvfrom(sockfd, &recv_data, sizeof(recv_data), 0, (struct sockaddr *)&server_addr, &addr_len);
    if (result < 0)
    {
        goto __exit;
    }

    if (recv_data == RT_FALSE)
    {
        result = -RT_ERROR;
        goto __exit;
    }

__exit:
    if (result > 0)
    {
        LOG_D("Set network interface device(%s) internet status up.", netdev->name);
        netdev->flags |= NETDEV_FLAG_INTERNET_UP;       
    }
    else
    {
        LOG_D("Set network interface device(%s) internet status down.", netdev->name);
        netdev->flags &= ~NETDEV_FLAG_INTERNET_UP;
    }

    if (sockfd >= 0)
    {
        skt_ops->closesocket(sockfd);
    }
257 258 259
}

/**
260
 * This function will check SAL network interface device internet status.
261
 *
262
 * @param netdev the network interface device to check
263
 */
264
int sal_check_netdev_internet_up(struct netdev *netdev)
265
{
266 267
    /* workqueue for network connect */
    struct rt_delayed_work *net_work = RT_NULL;
268

269
    RT_ASSERT(netdev);
270

271 272
    net_work = (struct rt_delayed_work *)rt_calloc(1, sizeof(struct rt_delayed_work));
    if (net_work == RT_NULL)
273
    {
274 275
        LOG_W("No memory for network interface device(%s) delay work.", netdev->name);
        return -1;
276 277
    }

278 279 280 281
    rt_delayed_work_init(net_work, check_netdev_internet_up_work, (void *)netdev);
    rt_work_submit(&(net_work->work), RT_TICK_PER_SECOND);
    
    return 0;
282 283 284
}

/**
 * Register a TLS protocol object as the global TLS protocol.
 * A later registration replaces the object registered before.
 *
 * @param pt TLS protocol object, asserted to be non-NULL
 *
 * @return 0: TLS protocol object register success
 */
#ifdef SAL_USING_TLS
int sal_proto_tls_register(const struct sal_proto_tls *pt)
{
    RT_ASSERT(pt);

    /* the global pointer is not const-qualified, so the qualifier is cast away */
    proto_tls = (struct sal_proto_tls *) pt;
    return 0;
}
#endif
300 301 302

/**
 * This function will get sal socket object by sal socket descriptor.
303 304 305
 *
 * @param socket sal socket index
 *
306
 * @return sal socket object of the current sal socket index
307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327
 */
struct sal_socket *sal_get_socket(int socket)
{
    struct sal_socket_table *st = &socket_table;

    if (socket < 0 || socket >= (int) st->max_socket)
    {
        return RT_NULL;
    }

    socket = socket - SAL_SOCKET_OFFSET;
    /* check socket structure valid or not */
    if (st->sockets[socket]->magic != SAL_SOCKET_MAGIC)
    {
        return RT_NULL;
    }

    return st->sockets[socket];
}

/**
 * Take the SAL core lock, waiting forever for it.
 * A failure to take the lock trips an assertion.
 *
 * @note please don't invoke it on ISR.
 */
static void sal_lock(void)
{
    if (rt_mutex_take(&sal_core_lock, RT_WAITING_FOREVER) != RT_EOK)
    {
        RT_ASSERT(0);
    }
}

/**
 * Release the SAL core lock taken by sal_lock().
 *
 * @note please don't invoke it on ISR.
 */
static void sal_unlock(void)
{
    (void) rt_mutex_release(&sal_core_lock);
}

/**
354
 * This function will initialize sal socket object and set socket options
355 356 357 358
 *
 * @param family    protocol family
 * @param type      socket type
 * @param protocol  transfer Protocol
359
 * @param res       sal socket object address
360 361 362 363
 *
 * @return  0 : socket initialize success
 *         -1 : input the wrong family
 *         -2 : input the wrong socket type
364
 *         -3 : get network interface failed
365 366 367
 */
static int socket_init(int family, int type, int protocol, struct sal_socket **res)
{
368

369
    struct sal_socket *sock;
370
    struct sal_proto_family *pf;
371 372
    struct netdev *netdv_def = netdev_default;
    struct netdev *netdev = RT_NULL;
373
    rt_bool_t flag = RT_FALSE;
374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389

    if (family < 0 || family > AF_MAX)
    {
        return -1;
    }

    if (type < 0 || type > SOCK_MAX)
    {
        return -2;
    }

    sock = *res;
    sock->domain = family;
    sock->type = type;
    sock->protocol = protocol;

390
    /* get socket operations from network interface device */
391 392 393 394 395 396
    if (netdv_def == RT_NULL)
    {
        LOG_E("not find default network interface device for socket create.");
        return -3;
    }

H
HubretXie 已提交
397
    if (netdev_is_up(netdv_def) && netdev_is_link_up(netdv_def))
398
    {
399 400 401
        /* check default network interface device protocol family */
        pf = (struct sal_proto_family *) netdv_def->sal_user_data;
        if (pf != RT_NULL && pf->skt_ops && (pf->family == family || pf->sec_family == family))
402
        {
403
            sock->netdev = netdv_def;
404
            flag = RT_TRUE;
405 406 407
        }
    }
    
408
    if (flag == RT_FALSE)
409
    {
410 411 412 413 414 415 416 417 418
        /* get network interface device by protocol family */
        netdev = netdev_get_by_family(family);
        if (netdev == RT_NULL)
        {
            LOG_E("not find network interface device by protocol family(%d).", family);
            return -3;
        }

        sock->netdev = netdev;
419 420 421 422 423 424 425 426 427 428 429 430
    }

    return 0;
}

static int socket_alloc(struct sal_socket_table *st, int f_socket)
{
    int idx;

    /* find an empty socket entry */
    for (idx = f_socket; idx < (int) st->max_socket; idx++)
    {
431 432 433
        if (st->sockets[idx] == RT_NULL || 
                st->sockets[idx]->netdev == RT_NULL)
        {
434
            break;
435
        }
436 437 438 439 440 441 442 443
    }

    /* allocate a larger sockte container */
    if (idx == (int) st->max_socket &&  st->max_socket < SAL_SOCKETS_NUM)
    {
        int cnt, index;
        struct sal_socket **sockets;

444 445
        /* increase the number of socket with 4 step length */
        cnt = st->max_socket + SOCKET_TABLE_STEP_LEN;
446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464
        cnt = cnt > SAL_SOCKETS_NUM ? SAL_SOCKETS_NUM : cnt;

        sockets = rt_realloc(st->sockets, cnt * sizeof(struct sal_socket *));
        if (sockets == RT_NULL)
            goto __result; /* return st->max_socket */

        /* clean the new allocated fds */
        for (index = st->max_socket; index < cnt; index++)
        {
            sockets[index] = RT_NULL;
        }

        st->sockets = sockets;
        st->max_socket = cnt;
    }

    /* allocate  'struct sal_socket' */
    if (idx < (int) st->max_socket && st->sockets[idx] == RT_NULL)
    {
465
        st->sockets[idx] = rt_calloc(1, sizeof(struct sal_socket));
466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496
        if (st->sockets[idx] == RT_NULL)
        {
            idx = st->max_socket;
        }
    }

__result:
    return idx;
}

static int socket_new(void)
{
    struct sal_socket *sock;
    struct sal_socket_table *st = &socket_table;
    int idx;

    sal_lock();

    /* find an empty sal socket entry */
    idx = socket_alloc(st, 0);

    /* can't find an empty sal socket entry */
    if (idx == (int) st->max_socket)
    {
        idx = -(1 + SAL_SOCKET_OFFSET);
        goto __result;
    }

    sock = st->sockets[idx];
    sock->socket = idx + SAL_SOCKET_OFFSET;
    sock->magic = SAL_SOCKET_MAGIC;
497
    sock->netdev = RT_NULL;
498 499 500 501
    sock->user_data = RT_NULL;
#ifdef SAL_USING_TLS
    sock->user_data_tls = RT_NULL;
#endif
502 503 504 505 506 507 508 509 510 511

__result:
    sal_unlock();
    return idx + SAL_SOCKET_OFFSET;
}

/*
 * Accept a connection on a listening sal socket and wrap the accepted
 * protocol socket in a new sal socket.
 *
 * Returns the new sal socket descriptor, or -1 on any failure.
 */
int sal_accept(int socket, struct sockaddr *addr, socklen_t *addrlen)
{
    int proto_fd;
    int sal_fd;
    int rc;
    struct sal_socket *sock;
    struct sal_socket *accepted;
    struct sal_proto_family *pf;

    /* get the socket object by socket descriptor */
    SAL_SOCKET_OBJ_GET(sock, socket);

    /* check the network interface socket operations */
    SAL_NETDEV_SOCKETOPS_VALID(sock->netdev, pf, accept);

    proto_fd = pf->skt_ops->accept((int) sock->user_data, addr, addrlen);
    if (proto_fd == -1)
    {
        return -1;
    }

    /* allocate a new socket structure for the accepted connection */
    sal_fd = socket_new();
    if (sal_fd < 0)
    {
        pf->skt_ops->closesocket(proto_fd);
        return -1;
    }
    accepted = sal_get_socket(sal_fd);

    /* the new socket inherits domain, type and protocol of the listener */
    rc = socket_init(sock->domain, sock->type, sock->protocol, &accepted);
    if (rc < 0)
    {
        pf->skt_ops->closesocket(proto_fd);
        /* wipe the object so that it no longer passes as a valid socket */
        rt_memset(accepted, 0x00, sizeof(struct sal_socket));
        LOG_E("New socket registered failed, return error %d.", rc);
        return -1;
    }

    /* socket structure user_data used to store the acquired new socket */
    accepted->user_data = (void *) proto_fd;

    return sal_fd;
}

554 555 556 557 558 559 560 561 562 563 564 565 566
/*
 * Copy the IPv4 address held in a socket address (read as sockaddr_in) into
 * an ip_addr_t. With an IPv6-only configuration nothing is copied and an
 * error is logged.
 */
static void sal_sockaddr_to_ipaddr(const struct sockaddr *name, ip_addr_t *local_ipaddr)
{
    const struct sockaddr_in *in4 = (const struct sockaddr_in *) name;

#if NETDEV_IPV4 && NETDEV_IPV6
    local_ipaddr->u_addr.ip4.addr = in4->sin_addr.s_addr;
#elif NETDEV_IPV4
    local_ipaddr->addr = in4->sin_addr.s_addr;
#elif NETDEV_IPV6
    LOG_E("not support IPV6");
#endif /* NETDEV_IPV4 && NETDEV_IPV6 */
}

567 568 569
/*
 * Bind a sal socket to a local address.
 *
 * For INADDR_ANY the socket stays on its current network interface device.
 * Otherwise the device owning the address is looked up; when it differs from
 * the current one, a protocol socket is opened on the new device and replaces
 * the old protocol socket before the bind is forwarded.
 *
 * Returns the result of the protocol bind operation, or -1 when the socket,
 * the device or a required socket operation cannot be found, or when the
 * replacement protocol socket cannot be created.
 */
int sal_bind(int socket, const struct sockaddr *name, socklen_t namelen)
{
    struct sal_socket *sock;
    struct sal_proto_family *pf;
    struct netdev *new_netdev;
    ip_addr_t local_addr;

    RT_ASSERT(name);

    /* get the socket object by socket descriptor */
    SAL_SOCKET_OBJ_GET(sock, socket);

    /* bind network interface by ip address */
    sal_sockaddr_to_ipaddr(name,  &local_addr);

    /* INADDR_ANY does not select a device: bind on the current one */
    if (local_addr.addr == INADDR_ANY)
    {
        SAL_NETDEV_SOCKETOPS_VALID(sock->netdev, pf, bind);
        return pf->skt_ops->bind((int) sock->user_data, name, namelen);
    }

    new_netdev = netdev_get_by_ipaddr(&local_addr);
    if (new_netdev == RT_NULL)
    {
        LOG_E("Not find network interface device ipaddr(%s).", inet_ntoa(local_addr));
        return -1;
    }

    /* change network interface device parameter in sal socket object */
    if (sock->netdev != new_netdev)
    {
        struct sal_proto_family *old_pf, *new_pf;
        int new_socket = 0;

        /* both devices must provide the operations used for the switch */
        SAL_NETDEV_SOCKETOPS_VALID(sock->netdev, old_pf, closesocket);
        SAL_NETDEV_SOCKETOPS_VALID(new_netdev, new_pf, socket);

        /*
         * open the new netdev socket first, so that a failure leaves the
         * sal socket attached to its still valid old protocol socket
         */
        new_socket = new_pf->skt_ops->socket(sock->domain, sock->type, sock->protocol);
        if (new_socket < 0)
        {
            return -1;
        }

        /* close old netdev socket: the protocol socket, not the sal descriptor */
        old_pf->skt_ops->closesocket((int) sock->user_data);

        sock->netdev = new_netdev;
        sock->user_data = (void *) new_socket;
    }

    /* check the network interface socket operation */
    SAL_NETDEV_SOCKETOPS_VALID(sock->netdev, pf, bind);
    return pf->skt_ops->bind((int) sock->user_data, name, namelen);
}

int sal_shutdown(int socket, int how)
{
    struct sal_socket *sock;
625 626
    struct sal_proto_family *pf;
    int error = 0;
627

628 629
    /* get the socket object by socket descriptor */
    SAL_SOCKET_OBJ_GET(sock, socket);
630

631 632 633
    /* shutdown operation not nead to check network interface status */
    /* check the network interface socket opreation */
    SAL_NETDEV_SOCKETOPS_VALID(sock->netdev, pf, shutdown);
634

635
    if (pf->skt_ops->shutdown((int) sock->user_data, how) == 0)
636
    {
637 638 639 640 641 642 643 644 645
#ifdef SAL_USING_TLS
        if (SAL_SOCKOPS_PROTO_TLS_VALID(sock, closesocket))
        {
            if (proto_tls->ops->closesocket(sock->user_data_tls) < 0)
            {
                return -1;
            }
        }
#endif
646 647 648 649 650
        error = 0;
    }
    else
    {
        error = -1;
651 652
    }

653 654 655 656 657
    /* free socket */
    rt_free(sock);
    socket_table.sockets[socket] = RT_NULL;

    return error;
658 659 660 661 662
}

/*
 * Forward getpeername to the protocol socket behind a sal socket.
 * Returns the protocol result, or -1 when the socket or the operation
 * cannot be found.
 */
int sal_getpeername(int socket, struct sockaddr *name, socklen_t *namelen)
{
    struct sal_proto_family *pf;
    struct sal_socket *sock;
    int proto_fd;

    /* resolve the descriptor; returns -1 from here when it is unknown */
    SAL_SOCKET_OBJ_GET(sock, socket);

    /* the network interface must implement getpeername; returns -1 otherwise */
    SAL_NETDEV_SOCKETOPS_VALID(sock->netdev, pf, getpeername);

    proto_fd = (int) sock->user_data;
    return pf->skt_ops->getpeername(proto_fd, name, namelen);
}

/*
 * Forward getsockname to the protocol socket behind a sal socket.
 * Returns the protocol result, or -1 when the socket or the operation
 * cannot be found.
 */
int sal_getsockname(int socket, struct sockaddr *name, socklen_t *namelen)
{
    struct sal_proto_family *pf;
    struct sal_socket *sock;
    int proto_fd;

    /* resolve the descriptor; returns -1 from here when it is unknown */
    SAL_SOCKET_OBJ_GET(sock, socket);

    /* the network interface must implement getsockname; returns -1 otherwise */
    SAL_NETDEV_SOCKETOPS_VALID(sock->netdev, pf, getsockname);

    proto_fd = (int) sock->user_data;
    return pf->skt_ops->getsockname(proto_fd, name, namelen);
}

/*
 * Forward getsockopt to the protocol socket behind a sal socket.
 * Returns the protocol result, or -1 when the socket or the operation
 * cannot be found.
 */
int sal_getsockopt(int socket, int level, int optname, void *optval, socklen_t *optlen)
{
    struct sal_proto_family *pf;
    struct sal_socket *sock;
    int proto_fd;

    /* resolve the descriptor; returns -1 from here when it is unknown */
    SAL_SOCKET_OBJ_GET(sock, socket);

    /* the network interface must implement getsockopt; returns -1 otherwise */
    SAL_NETDEV_SOCKETOPS_VALID(sock->netdev, pf, getsockopt);

    proto_fd = (int) sock->user_data;
    return pf->skt_ops->getsockopt(proto_fd, level, optname, optval, optlen);
}

/*
 * Set an option on a sal socket.
 *
 * With TLS support compiled in, options of level SOL_TLS are handled by the
 * registered TLS protocol; every other option is forwarded to the protocol
 * socket.
 *
 * Returns the handler result, -1 when the socket, the setsockopt operation
 * or the SOL_TLS option name is unknown, and 0 for a known SOL_TLS option
 * whose TLS handler is not available.
 */
int sal_setsockopt(int socket, int level, int optname, const void *optval, socklen_t optlen)
{
    struct sal_proto_family *pf;
    struct sal_socket *sock;

    /* get the socket object by socket descriptor */
    SAL_SOCKET_OBJ_GET(sock, socket);

    /* check the network interface socket operation */
    SAL_NETDEV_SOCKETOPS_VALID(sock->netdev, pf, setsockopt);

#ifdef SAL_USING_TLS
    if (level == SOL_TLS)
    {
        /* each EXEC macro returns the handler result when the handler is usable */
        if (optname == TLS_CRET_LIST)
        {
            SAL_SOCKOPT_PROTO_TLS_EXEC(sock, set_cret_list, optval, optlen);
        }
        else if (optname == TLS_CIPHERSUITE_LIST)
        {
            SAL_SOCKOPT_PROTO_TLS_EXEC(sock, set_ciphersurite, optval, optlen);
        }
        else if (optname == TLS_PEER_VERIFY)
        {
            SAL_SOCKOPT_PROTO_TLS_EXEC(sock, set_peer_verify, optval, optlen);
        }
        else if (optname == TLS_DTLS_ROLE)
        {
            SAL_SOCKOPT_PROTO_TLS_EXEC(sock, set_dtls_role, optval, optlen);
        }
        else
        {
            return -1;
        }

        /* known option, but no usable TLS handler for it */
        return 0;
    }
#endif /* SAL_USING_TLS */

    return pf->skt_ops->setsockopt((int) sock->user_data, level, optname, optval, optlen);
}

/*
 * Connect a sal socket to a remote address. For a TLS/DTLS socket the TLS
 * connect operation is run after the protocol connect has succeeded.
 *
 * Returns the protocol connect result, or -1 when the socket is unknown, the
 * network interface is not usable, the connect operation is missing, or the
 * TLS connect failed.
 */
int sal_connect(int socket, const struct sockaddr *name, socklen_t namelen)
{
    struct sal_proto_family *pf;
    struct sal_socket *sock;
    int ret;

    /* get the socket object by socket descriptor */
    SAL_SOCKET_OBJ_GET(sock, socket);

    /* the network interface must be up with its link up */
    SAL_NETDEV_IS_COMMONICABLE(sock->netdev);
    /* and it must implement connect */
    SAL_NETDEV_SOCKETOPS_VALID(sock->netdev, pf, connect);

    ret = pf->skt_ops->connect((int) sock->user_data, name, namelen);
#ifdef SAL_USING_TLS
    /* the TLS connect only runs on top of an established protocol connection */
    if (ret >= 0 && SAL_SOCKOPS_PROTO_TLS_VALID(sock, connect) &&
            proto_tls->ops->connect(sock->user_data_tls) < 0)
    {
        return -1;
    }
#endif

    return ret;
}

int sal_listen(int socket, int backlog)
{
    struct sal_socket *sock;
782
    struct sal_proto_family *pf;
783

784 785
    /* get the socket object by socket descriptor */
    SAL_SOCKET_OBJ_GET(sock, socket);
786

787 788
    /* check the network interface socket opreation */
    SAL_NETDEV_SOCKETOPS_VALID(sock->netdev, pf, listen);
789

790
    return pf->skt_ops->listen((int) sock->user_data, backlog);
791 792 793 794 795 796
}

/*
 * Receive data on a sal socket. A TLS/DTLS socket with a usable TLS recv
 * operation reads through the TLS protocol (flags, from and fromlen are not
 * used on that path); every other socket is forwarded to the protocol
 * recvfrom.
 *
 * Returns the number reported by the receive operation, or -1 when the
 * socket is unknown, the network interface is not usable, the recvfrom
 * operation is missing, or the TLS recv reported an error.
 */
int sal_recvfrom(int socket, void *mem, size_t len, int flags,
             struct sockaddr *from, socklen_t *fromlen)
{
    struct sal_proto_family *pf;
    struct sal_socket *sock;

    /* get the socket object by socket descriptor */
    SAL_SOCKET_OBJ_GET(sock, socket);

    /* the network interface must be up with its link up */
    SAL_NETDEV_IS_COMMONICABLE(sock->netdev);
    /* and it must implement recvfrom */
    SAL_NETDEV_SOCKETOPS_VALID(sock->netdev, pf, recvfrom);

#ifdef SAL_USING_TLS
    if (SAL_SOCKOPS_PROTO_TLS_VALID(sock, recv))
    {
        int received = proto_tls->ops->recv(sock->user_data_tls, mem, len);

        /* every negative TLS result is reported as -1 */
        return (received < 0) ? -1 : received;
    }
#endif

    return pf->skt_ops->recvfrom((int) sock->user_data, mem, len, flags, from, fromlen);
}

/*
 * Send data on a sal socket. A TLS/DTLS socket with a usable TLS send
 * operation writes through the TLS protocol (flags, to and tolen are not
 * used on that path); every other socket is forwarded to the protocol
 * sendto.
 *
 * Returns the number reported by the send operation, or -1 when the socket
 * is unknown, the network interface is not usable, the sendto operation is
 * missing, or the TLS send reported an error.
 */
int sal_sendto(int socket, const void *dataptr, size_t size, int flags,
           const struct sockaddr *to, socklen_t tolen)
{
    struct sal_proto_family *pf;
    struct sal_socket *sock;

    /* get the socket object by socket descriptor */
    SAL_SOCKET_OBJ_GET(sock, socket);

    /* the network interface must be up with its link up */
    SAL_NETDEV_IS_COMMONICABLE(sock->netdev);
    /* and it must implement sendto */
    SAL_NETDEV_SOCKETOPS_VALID(sock->netdev, pf, sendto);

#ifdef SAL_USING_TLS
    if (SAL_SOCKOPS_PROTO_TLS_VALID(sock, send))
    {
        int sent = proto_tls->ops->send(sock->user_data_tls, dataptr, size);

        /* every negative TLS result is reported as -1 */
        return (sent < 0) ? -1 : sent;
    }
#endif

    return pf->skt_ops->sendto((int) sock->user_data, dataptr, size, flags, to, tolen);
}

int sal_socket(int domain, int type, int protocol)
{
    int retval;
    int socket, proto_socket;
    struct sal_socket *sock;
866
    struct sal_proto_family *pf;
867 868 869 870 871 872 873

    /* allocate a new socket and registered socket options */
    socket = socket_new();
    if (socket < 0)
    {
        return -1;
    }
874 875

    /* get sal socket object by socket descriptor */
876
    sock = sal_get_socket(socket);
877 878 879 880
    if (sock == RT_NULL)
    {
        return -1;
    }
881

882
    /* Initialize sal socket object */
883 884 885 886 887 888 889
    retval = socket_init(domain, type, protocol, &sock);
    if (retval < 0)
    {
        LOG_E("SAL socket protocol family input failed, return error %d.", retval);
        return -1;
    }

890 891
    /* valid the network interface socket opreation */
    SAL_NETDEV_SOCKETOPS_VALID(sock->netdev, pf, socket);
892

893
    proto_socket = pf->skt_ops->socket(domain, type, protocol);
894 895
    if (proto_socket >= 0)
    {
896 897 898
#ifdef SAL_USING_TLS
        if (SAL_SOCKOPS_PROTO_TLS_VALID(sock, socket))
        {
899
            sock->user_data_tls = proto_tls->ops->socket(socket);
900 901 902 903 904 905
            if (sock->user_data_tls == RT_NULL)
            {
                return -1;
            }
        }
#endif
906 907 908 909 910 911 912 913 914 915
        sock->user_data = (void *) proto_socket;
        return sock->socket;
    }

    return -1;
}

int sal_closesocket(int socket)
{
    struct sal_socket *sock;
916 917
    struct sal_proto_family *pf;
    int error = 0;
918

919 920
    /* get the socket object by socket descriptor */
    SAL_SOCKET_OBJ_GET(sock, socket);
921

922 923 924
    /* clsoesocket operation not nead to vaild network interface status */
    /* valid the network interface socket opreation */
    SAL_NETDEV_SOCKETOPS_VALID(sock->netdev, pf, socket);
925

926
    if (pf->skt_ops->closesocket((int) sock->user_data) == 0)
927
    {
928 929 930 931 932 933 934 935 936
#ifdef SAL_USING_TLS
        if (SAL_SOCKOPS_PROTO_TLS_VALID(sock, closesocket))
        {
            if (proto_tls->ops->closesocket(sock->user_data_tls) < 0)
            {
                return -1;
            }
        }
#endif
937 938 939 940 941
        error = 0;
    }
    else
    {
        error = -1;
942 943
    }

944 945 946 947 948
    /* free socket */
    rt_free(sock);        
    socket_table.sockets[socket] = RT_NULL;

    return error;
949 950 951 952 953
}

int sal_ioctlsocket(int socket, long cmd, void *arg)
{
    struct sal_socket *sock;
954
    struct sal_proto_family *pf;
955

956 957
    /* get the socket object by socket descriptor */
    SAL_SOCKET_OBJ_GET(sock, socket);
958

959 960
    /* check the network interface socket opreation */
    SAL_NETDEV_SOCKETOPS_VALID(sock->netdev, pf, ioctlsocket);
961

962
    return pf->skt_ops->ioctlsocket((int) sock->user_data, cmd, arg);
963 964
}

965
#ifdef SAL_USING_POSIX
/*
 * Forward a poll request to the network interface behind the sal socket
 * whose descriptor is stored in file->data.
 * Returns the protocol poll result, or -1 when the socket is unknown, the
 * network interface is not usable or the poll operation is missing.
 */
int sal_poll(struct dfs_fd *file, struct rt_pollreq *req)
{
    struct sal_proto_family *pf;
    struct sal_socket *sock;
    int socket;

    socket = (int) file->data;

    /* get the socket object by socket descriptor */
    SAL_SOCKET_OBJ_GET(sock, socket);

    /* the network interface must be up with its link up */
    SAL_NETDEV_IS_COMMONICABLE(sock->netdev);
    /* and it must implement poll */
    SAL_NETDEV_SOCKETOPS_VALID(sock->netdev, pf, poll);

    return pf->skt_ops->poll(file, req);
}
#endif
983 984 985

struct hostent *sal_gethostbyname(const char *name)
{
986 987
    struct netdev *netdev = netdev_default;
    struct sal_proto_family *pf;
988

989
    if (SAL_NETDEV_NETDBOPS_VALID(netdev, pf, gethostbyname))
990
    {
991 992 993 994 995 996 997
        return pf->netdb_ops->gethostbyname(name);
    }
    else
    {
        /* get the first network interface device with the link up status */
        netdev = netdev_get_first_link_up();
        if (SAL_NETDEV_NETDBOPS_VALID(netdev, pf, gethostbyname))
998
        {
999
            return pf->netdb_ops->gethostbyname(name);
1000 1001 1002 1003 1004 1005 1006 1007 1008
        }
    }

    return RT_NULL;
}

/*
 * Variant of sal_gethostbyname() that forwards to the gethostbyname_r
 * operation, using the default network interface device or, failing that,
 * the first device that has its link up.
 * Returns the resolver result, or -1 when no device can resolve names.
 */
int sal_gethostbyname_r(const char *name, struct hostent *ret, char *buf,
                size_t buflen, struct hostent **result, int *h_errnop)
{
    struct sal_proto_family *pf;
    struct netdev *netdev = netdev_default;

    if (!SAL_NETDEV_NETDBOPS_VALID(netdev, pf, gethostbyname_r))
    {
        /* get the first network interface device with the link up status */
        netdev = netdev_get_first_link_up();
        if (!SAL_NETDEV_NETDBOPS_VALID(netdev, pf, gethostbyname_r))
        {
            return -1;
        }
    }

    /* pf was loaded by the check that succeeded */
    return pf->netdb_ops->gethostbyname_r(name, ret, buf, buflen, result, h_errnop);
}

/*
 * Forward getaddrinfo to the default network interface device, or to the
 * first device that has its link up when the default one cannot be used.
 * Returns the resolver result, or -1 when no device can resolve names.
 */
int sal_getaddrinfo(const char *nodename,
       const char *servname,
       const struct addrinfo *hints,
       struct addrinfo **res)
{
    struct sal_proto_family *pf;
    struct netdev *netdev = netdev_default;

    if (!SAL_NETDEV_NETDBOPS_VALID(netdev, pf, getaddrinfo))
    {
        /* get the first network interface device with the link up status */
        netdev = netdev_get_first_link_up();
        if (!SAL_NETDEV_NETDBOPS_VALID(netdev, pf, getaddrinfo))
        {
            return -1;
        }
    }

    /* pf was loaded by the check that succeeded */
    return pf->netdb_ops->getaddrinfo(nodename, servname, hints, res);
}
1053 1054 1055

void sal_freeaddrinfo(struct addrinfo *ai)
{
1056 1057
    struct netdev *netdev = netdev_default;
    struct sal_proto_family *pf;
1058

1059
    if (SAL_NETDEV_NETDBOPS_VALID(netdev, pf, freeaddrinfo))
1060
    {
1061 1062 1063 1064 1065 1066 1067
        pf->netdb_ops->freeaddrinfo(ai);
    }
    else
    {
        /* get the first network interface device with the link up status */
        netdev = netdev_get_first_link_up();
        if (SAL_NETDEV_NETDBOPS_VALID(netdev, pf, freeaddrinfo))
1068
        {
1069
            pf->netdb_ops->freeaddrinfo(ai);
1070 1071 1072
        }
    }
}
1073