/*
 * Copyright (c) 2006-2018, RT-Thread Development Team
 *
 * SPDX-License-Identifier: Apache-2.0
 *
 * Change Logs:
 * Date           Author       Notes
 * 2018-05-23     ChenYong     First version
 * 2018-11-12     ChenYong     Add TLS support
 */

#include <rtthread.h>
#include <rthw.h>
14
#include <sys/time.h>
15 16 17

#include <sal_socket.h>
#include <sal_netdb.h>
18 19 20
#ifdef SAL_USING_TLS
#include <sal_tls.h>
#endif
21
#include <sal.h>
22
#include <netdev.h>
23

24 25 26 27 28 29 30 31 32
#include <ipc/workqueue.h>

/* check system workqueue stack size */
#if RT_SYSTEM_WORKQUEUE_STACKSIZE < 1536
#error "The system workqueue stack size must more than 1536 bytes"
#endif

#define DBG_TAG                        "sal.skt"
#define DBG_LVL                        DBG_INFO
33 34
#include <rtdbg.h>

35 36
/* number of slots added to the socket table each time it has to grow */
#define SOCKET_TABLE_STEP_LEN          4

/*
 * Dynamically sized table of SAL socket objects.
 * A table index plus SAL_SOCKET_OFFSET is the SAL socket descriptor.
 */
struct sal_socket_table
{
    uint32_t max_socket;            /* number of slots currently allocated */
    struct sal_socket **sockets;    /* slot array; RT_NULL marks a free slot */
};

#ifdef SAL_USING_TLS
/* TLS protocol operations installed by sal_proto_tls_register() */
static struct sal_proto_tls *proto_tls;
#endif

/* the global socket table */
static struct sal_socket_table socket_table;
/* lock taken by sal_lock()/sal_unlock() around socket table updates */
static struct rt_mutex sal_core_lock;
/* set to RT_TRUE once sal_init() has succeeded */
static rt_bool_t init_ok = RT_FALSE;

/* true when the SAL socket was created with a TLS or DTLS protocol */
#define IS_SOCKET_PROTO_TLS(sock)                (((sock)->protocol == PROTOCOL_TLS) || \
                                                 ((sock)->protocol == PROTOCOL_DTLS))
/*
 * True when a TLS protocol is registered, it implements operation `name`,
 * and the socket is a TLS/DTLS socket. Only usable when SAL_USING_TLS is
 * defined (it references proto_tls).
 */
#define SAL_SOCKOPS_PROTO_TLS_VALID(sock, name)  (proto_tls && (proto_tls->ops->name) && IS_SOCKET_PROTO_TLS(sock))

/*
 * Return from the calling function with the result of the TLS option handler
 * `name` when SAL_SOCKOPS_PROTO_TLS_VALID(sock, name) holds; otherwise do nothing.
 */
#define SAL_SOCKOPT_PROTO_TLS_EXEC(sock, name, optval, optlen)                    \
do {                                                                              \
    if (SAL_SOCKOPS_PROTO_TLS_VALID(sock, name)) {                                \
        return proto_tls->ops->name((sock)->user_data_tls, (optval), (optlen));   \
    }                                                                             \
} while (0)

/* look up the SAL socket object for `socket`; return -1 from the caller if none */
#define SAL_SOCKET_OBJ_GET(sock, socket)                                          \
do {                                                                              \
    (sock) = sal_get_socket(socket);                                              \
    if ((sock) == RT_NULL) {                                                      \
        return -1;                                                                \
    }                                                                             \
} while (0)

/* return -1 from the caller unless `netdev` exists and is up */
#define SAL_NETDEV_IS_UP(netdev)                                                  \
do {                                                                              \
    if ((netdev) == RT_NULL || !netdev_is_up(netdev)) {                           \
        return -1;                                                                \
    }                                                                             \
} while (0)

/*
 * Load the SAL protocol family of `netdev` into `pf` and return -1 from the
 * caller unless the device, its protocol family, its socket operations table
 * and the operation `ops` all exist.
 */
#define SAL_NETDEV_SOCKETOPS_VALID(netdev, pf, ops)                               \
do {                                                                              \
    if ((netdev) == RT_NULL) {                                                    \
        return -1;                                                                \
    }                                                                             \
    (pf) = (struct sal_proto_family *) (netdev)->sal_user_data;                   \
    if ((pf) == RT_NULL || (pf)->skt_ops == RT_NULL ||                            \
        (pf)->skt_ops->ops == RT_NULL) {                                          \
        return -1;                                                                \
    }                                                                             \
} while (0)

/*
 * Expression: true when `netdev` exists and is up, has a SAL protocol family
 * (stored into `pf`), a netdb operations table, and netdb operation `ops`.
 */
#define SAL_NETDEV_NETDBOPS_VALID(netdev, pf, ops)                                \
    ((netdev) && netdev_is_up(netdev) &&                                          \
    ((pf) = (struct sal_proto_family *) (netdev)->sal_user_data) != RT_NULL &&    \
    (pf)->netdb_ops && (pf)->netdb_ops->ops)

/**
94
 * SAL (Socket Abstraction Layer) initialize.
95
 *
96
 * @return result  0: initialize success
97
 *                -1: initialize failed
98 99 100
 */
int sal_init(void)
{
101
    int cn;
102

103
    if (init_ok)
104 105 106 107 108
    {
        LOG_D("Socket Abstraction Layer is already initialized.");
        return 0;
    }

109 110 111 112 113 114 115 116 117
    /* init sal socket table */
    cn = SOCKET_TABLE_STEP_LEN < SAL_SOCKETS_NUM ? SOCKET_TABLE_STEP_LEN : SAL_SOCKETS_NUM;
    socket_table.max_socket = cn;
    socket_table.sockets = rt_calloc(1, cn * sizeof(struct sal_socket *));
    if (socket_table.sockets == RT_NULL)
    {
        LOG_E("No memory for socket table.\n");
        return -1;
    }
118

119 120 121 122 123 124 125 126 127 128
    /* create sal socket lock */
    rt_mutex_init(&sal_core_lock, "sal_lock", RT_IPC_FLAG_FIFO);

    LOG_I("Socket Abstraction Layer initialize success.");
    init_ok = RT_TRUE;

    return 0;
}
INIT_COMPONENT_EXPORT(sal_init);

129 130
/* check SAL network interface device internet status */
static void check_netdev_internet_up_work(struct rt_work *work, void *work_data)
131
{
132 133
#define SAL_INTERNET_VERSION   0x00
#define SAL_INTERNET_BUFF_LEN  12
134
#define SAL_INTERNET_TIMEOUT   (2)
135

136 137
#define SAL_INTERNET_HOST      "link.rt-thread.org"
#define SAL_INTERNET_PORT      8101
138

139 140
#define SAL_INTERNET_MONTH_LEN 4
#define SAL_INTERNET_DATE_LEN  16
141

142
    int index, sockfd = -1, result = 0;
143 144 145 146 147 148
    struct sockaddr_in server_addr;
    struct hostent *host;
    struct timeval timeout;
    struct netdev *netdev = (struct netdev *)work_data;
    socklen_t addr_len = sizeof(struct sockaddr_in);
    char send_data[SAL_INTERNET_BUFF_LEN], recv_data = 0;
149
    struct rt_delayed_work *delay_work = (struct rt_delayed_work *)work;
150

151
    const char month[][SAL_INTERNET_MONTH_LEN] = {"Jan", "Feb", "Mar", "Apr", "May", "Jun", "Jul", "Aug", "Sep", "Oct", "Nov", "Dec"};
152 153 154 155 156 157 158
    char date[SAL_INTERNET_DATE_LEN];
    int moth_num = 0;

    struct sal_proto_family *pf = (struct sal_proto_family *) netdev->sal_user_data;
    const struct sal_socket_ops *skt_ops;

    if (work)
159
    {
160
        rt_free(delay_work);
161 162
    }

163 164
    /* get network interface socket operations */
    if (pf == RT_NULL || pf->skt_ops == RT_NULL)
165
    {
166 167
        result = -RT_ERROR;
        goto __exit;
168 169
    }

170 171 172 173 174 175
    host = (struct hostent *) pf->netdb_ops->gethostbyname(SAL_INTERNET_HOST);
    if (host == RT_NULL)
    {
        result = -RT_ERROR;
        goto __exit;
    }
176

H
HubretXie 已提交
177
    skt_ops = pf->skt_ops;
178
    if ((sockfd = skt_ops->socket(AF_INET, SOCK_DGRAM, 0)) < 0)
179 180 181 182
    {
        result = -RT_ERROR;
        goto __exit;
    }
183

184 185 186 187
    server_addr.sin_family = AF_INET;
    server_addr.sin_port = htons(SAL_INTERNET_PORT);
    server_addr.sin_addr = *((struct in_addr *)host->h_addr);
    rt_memset(&(server_addr.sin_zero), 0, sizeof(server_addr.sin_zero));
188

189 190
    timeout.tv_sec = SAL_INTERNET_TIMEOUT;
    timeout.tv_usec = 0;
191

192 193 194
    /* set receive and send timeout */
    skt_ops->setsockopt(sockfd, SOL_SOCKET, SO_RCVTIMEO, (void *) &timeout, sizeof(timeout));
    skt_ops->setsockopt(sockfd, SOL_SOCKET, SO_SNDTIMEO, (void *) &timeout, sizeof(timeout));
195

196 197 198
    /* get build moth value*/
    rt_memset(date, 0x00, SAL_INTERNET_DATE_LEN);
    rt_snprintf(date, SAL_INTERNET_DATE_LEN, "%s", __DATE__);
199

200
    for (index = 0; index < sizeof(month) / SAL_INTERNET_MONTH_LEN; index++)
201
    {
202
        if (rt_memcmp(date, month[index], SAL_INTERNET_MONTH_LEN - 1) == 0)
203
        {
204 205
            moth_num = index + 1;
            break;
206 207 208
        }
    }

209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226
    /* not find build month */
    if (moth_num == 0 || moth_num > sizeof(month) / SAL_INTERNET_MONTH_LEN)
    {
        result = -RT_ERROR;
        goto __exit;
    }

    rt_memset(send_data, 0x00, SAL_INTERNET_BUFF_LEN);
    send_data[0] = SAL_INTERNET_VERSION;
    for (index = 0; index < netdev->hwaddr_len; index++)
    {
        send_data[index + 1] = netdev->hwaddr[index] + moth_num;
    }
    send_data[9] = RT_VERSION;
    send_data[10] = RT_SUBVERSION;
    send_data[11] = RT_REVISION;

    skt_ops->sendto(sockfd, send_data, SAL_INTERNET_BUFF_LEN, 0,
227
                    (struct sockaddr *)&server_addr, sizeof(struct sockaddr));
228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244

    result = skt_ops->recvfrom(sockfd, &recv_data, sizeof(recv_data), 0, (struct sockaddr *)&server_addr, &addr_len);
    if (result < 0)
    {
        goto __exit;
    }

    if (recv_data == RT_FALSE)
    {
        result = -RT_ERROR;
        goto __exit;
    }

__exit:
    if (result > 0)
    {
        LOG_D("Set network interface device(%s) internet status up.", netdev->name);
245
        netdev->flags |= NETDEV_FLAG_INTERNET_UP;
246 247 248 249 250 251 252 253 254 255 256
    }
    else
    {
        LOG_D("Set network interface device(%s) internet status down.", netdev->name);
        netdev->flags &= ~NETDEV_FLAG_INTERNET_UP;
    }

    if (sockfd >= 0)
    {
        skt_ops->closesocket(sockfd);
    }
257 258 259
}

/**
260
 * This function will check SAL network interface device internet status.
261
 *
262
 * @param netdev the network interface device to check
263
 */
264
int sal_check_netdev_internet_up(struct netdev *netdev)
265
{
266 267
    /* workqueue for network connect */
    struct rt_delayed_work *net_work = RT_NULL;
268

269
    RT_ASSERT(netdev);
270

271 272
    net_work = (struct rt_delayed_work *)rt_calloc(1, sizeof(struct rt_delayed_work));
    if (net_work == RT_NULL)
273
    {
274 275
        LOG_W("No memory for network interface device(%s) delay work.", netdev->name);
        return -1;
276 277
    }

278 279
    rt_delayed_work_init(net_work, check_netdev_internet_up_work, (void *)netdev);
    rt_work_submit(&(net_work->work), RT_TICK_PER_SECOND);
280

281
    return 0;
282 283 284
}

/**
 * Register the TLS protocol implementation used for TLS/DTLS sockets.
 *
 * Only one implementation is stored; a later call replaces the earlier one.
 * The object is kept by pointer, so it must remain valid while SAL uses it.
 *
 * @param pt TLS protocol object (must not be RT_NULL)
 *
 * @return 0: TLS protocol object register success
 */
#ifdef SAL_USING_TLS
int sal_proto_tls_register(const struct sal_proto_tls *pt)
{
    RT_ASSERT(pt);

    proto_tls = (struct sal_proto_tls *) pt;
    return 0;
}
#endif

/**
 * This function will get sal socket object by sal socket descriptor.
303 304 305
 *
 * @param socket sal socket index
 *
306
 * @return sal socket object of the current sal socket index
307 308 309 310 311 312 313 314 315 316 317 318
 */
struct sal_socket *sal_get_socket(int socket)
{
    struct sal_socket_table *st = &socket_table;

    if (socket < 0 || socket >= (int) st->max_socket)
    {
        return RT_NULL;
    }

    socket = socket - SAL_SOCKET_OFFSET;
    /* check socket structure valid or not */
319
    RT_ASSERT(st->sockets[socket]->magic == SAL_SOCKET_MAGIC);
320 321 322 323 324

    return st->sockets[socket];
}

/**
 * Take the SAL core lock, waiting as long as necessary.
 *
 * @note please don't invoke it on ISR.
 */
static void sal_lock(void)
{
    if (rt_mutex_take(&sal_core_lock, RT_WAITING_FOREVER) != RT_EOK)
    {
        /* an unbounded wait is not expected to fail; treat failure as fatal */
        RT_ASSERT(0);
    }
}

/**
 * Release the SAL core lock previously taken by sal_lock().
 *
 * @note please don't invoke it on ISR.
 */
static void sal_unlock(void)
{
    (void) rt_mutex_release(&sal_core_lock);
}

/**
 * This function will clean the netdev.
 *
 * @note please don't invoke it on ISR.
 */
int sal_netdev_cleanup(struct netdev *netdev)
{
    int idx = 0, find_dev;

    do
    {
        find_dev = 0;
        sal_lock();
        for (idx = 0; idx < socket_table.max_socket; idx++)
        {
            if (socket_table.sockets[idx] && socket_table.sockets[idx]->netdev == netdev)
            {
                find_dev = 1;
                break;
            }
        }
        sal_unlock();
        if (find_dev)
        {
374
            rt_thread_mdelay(100);
375
        }
376 377
    }
    while (find_dev);
378 379 380 381

    return 0;
}

382
/**
383
 * This function will initialize sal socket object and set socket options
384 385 386 387
 *
 * @param family    protocol family
 * @param type      socket type
 * @param protocol  transfer Protocol
388
 * @param res       sal socket object address
389 390 391 392
 *
 * @return  0 : socket initialize success
 *         -1 : input the wrong family
 *         -2 : input the wrong socket type
393
 *         -3 : get network interface failed
394 395 396
 */
static int socket_init(int family, int type, int protocol, struct sal_socket **res)
{
397

398
    struct sal_socket *sock;
399
    struct sal_proto_family *pf;
400 401
    struct netdev *netdv_def = netdev_default;
    struct netdev *netdev = RT_NULL;
402
    rt_bool_t flag = RT_FALSE;
403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418

    if (family < 0 || family > AF_MAX)
    {
        return -1;
    }

    if (type < 0 || type > SOCK_MAX)
    {
        return -2;
    }

    sock = *res;
    sock->domain = family;
    sock->type = type;
    sock->protocol = protocol;

419
    if (netdv_def && netdev_is_up(netdv_def))
420
    {
421 422 423
        /* check default network interface device protocol family */
        pf = (struct sal_proto_family *) netdv_def->sal_user_data;
        if (pf != RT_NULL && pf->skt_ops && (pf->family == family || pf->sec_family == family))
424
        {
425
            sock->netdev = netdv_def;
426
            flag = RT_TRUE;
427 428
        }
    }
429

430
    if (flag == RT_FALSE)
431
    {
432 433 434 435 436 437 438 439 440
        /* get network interface device by protocol family */
        netdev = netdev_get_by_family(family);
        if (netdev == RT_NULL)
        {
            LOG_E("not find network interface device by protocol family(%d).", family);
            return -3;
        }

        sock->netdev = netdev;
441 442 443 444 445 446 447 448 449 450 451 452
    }

    return 0;
}

static int socket_alloc(struct sal_socket_table *st, int f_socket)
{
    int idx;

    /* find an empty socket entry */
    for (idx = f_socket; idx < (int) st->max_socket; idx++)
    {
453
        if (st->sockets[idx] == RT_NULL)
454
        {
455
            break;
456
        }
457 458 459 460 461 462 463 464
    }

    /* allocate a larger sockte container */
    if (idx == (int) st->max_socket &&  st->max_socket < SAL_SOCKETS_NUM)
    {
        int cnt, index;
        struct sal_socket **sockets;

465 466
        /* increase the number of socket with 4 step length */
        cnt = st->max_socket + SOCKET_TABLE_STEP_LEN;
467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485
        cnt = cnt > SAL_SOCKETS_NUM ? SAL_SOCKETS_NUM : cnt;

        sockets = rt_realloc(st->sockets, cnt * sizeof(struct sal_socket *));
        if (sockets == RT_NULL)
            goto __result; /* return st->max_socket */

        /* clean the new allocated fds */
        for (index = st->max_socket; index < cnt; index++)
        {
            sockets[index] = RT_NULL;
        }

        st->sockets = sockets;
        st->max_socket = cnt;
    }

    /* allocate  'struct sal_socket' */
    if (idx < (int) st->max_socket && st->sockets[idx] == RT_NULL)
    {
486
        st->sockets[idx] = rt_calloc(1, sizeof(struct sal_socket));
487 488 489 490 491 492 493 494 495 496
        if (st->sockets[idx] == RT_NULL)
        {
            idx = st->max_socket;
        }
    }

__result:
    return idx;
}

497 498 499 500 501 502 503 504 505
/**
 * Free the sal socket object in slot `idx` and mark the slot empty.
 *
 * @param st   socket table
 * @param idx  table index (not a socket descriptor)
 */
static void socket_free(struct sal_socket_table *st, int idx)
{
    struct sal_socket *victim = st->sockets[idx];

    st->sockets[idx] = RT_NULL;
    rt_free(victim);
}

static int socket_new(void)
{
    struct sal_socket *sock;
    struct sal_socket_table *st = &socket_table;
    int idx;

    sal_lock();

    /* find an empty sal socket entry */
    idx = socket_alloc(st, 0);

    /* can't find an empty sal socket entry */
    if (idx == (int) st->max_socket)
    {
        idx = -(1 + SAL_SOCKET_OFFSET);
        goto __result;
    }

    sock = st->sockets[idx];
    sock->socket = idx + SAL_SOCKET_OFFSET;
    sock->magic = SAL_SOCKET_MAGIC;
527
    sock->netdev = RT_NULL;
528 529 530 531
    sock->user_data = RT_NULL;
#ifdef SAL_USING_TLS
    sock->user_data_tls = RT_NULL;
#endif
532 533 534 535 536 537

__result:
    sal_unlock();
    return idx + SAL_SOCKET_OFFSET;
}

538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557
/**
 * Destroy the sal socket object behind descriptor `socket` and free its slot.
 *
 * Descriptors outside the table range are ignored. The magic value and the
 * netdev pointer are cleared before the object memory is released.
 *
 * @param socket sal socket descriptor
 */
static void socket_delete(int socket)
{
    struct sal_socket_table *st = &socket_table;
    struct sal_socket *sock;
    int idx = socket - SAL_SOCKET_OFFSET;

    if (idx >= 0 && idx < (int) st->max_socket)
    {
        sal_lock();

        sock = sal_get_socket(socket);
        RT_ASSERT(sock != RT_NULL);

        sock->magic = 0;
        sock->netdev = RT_NULL;
        socket_free(st, idx);

        sal_unlock();
    }
}

/**
 * Accept a connection on a listening sal socket.
 *
 * The protocol-level socket returned by the stack is wrapped in a new sal
 * socket. That socket copies the listening socket's family, type and protocol,
 * and uses the listening socket's network interface device.
 *
 * @param socket   listening sal socket descriptor
 * @param addr     peer address buffer, passed through to the protocol stack
 * @param addrlen  in/out length of addr, passed through to the protocol stack
 *
 * @return new sal socket descriptor, or -1 on failure
 */
int sal_accept(int socket, struct sockaddr *addr, socklen_t *addrlen)
{
    int new_socket;
    struct sal_socket *sock;
    struct sal_proto_family *pf;

    /* get the socket object by socket descriptor */
    SAL_SOCKET_OBJ_GET(sock, socket);

    /* check the network interface socket operations */
    SAL_NETDEV_SOCKETOPS_VALID(sock->netdev, pf, accept);

    new_socket = pf->skt_ops->accept((int) sock->user_data, addr, addrlen);
    if (new_socket != -1)
    {
        int retval;
        int new_sal_socket;
        struct sal_socket *new_sock;

        /* allocate a new socket structure and registered socket options */
        new_sal_socket = socket_new();
        new_sock = sal_get_socket(new_sal_socket);
        if (new_sock == RT_NULL)
        {
            pf->skt_ops->closesocket(new_socket);
            return -1;
        }

        retval = socket_init(sock->domain, sock->type, sock->protocol, &new_sock);
        if (retval < 0)
        {
            pf->skt_ops->closesocket(new_socket);
            /* socket init failed, delete socket. The object must NOT be wiped
             * first: socket_delete() looks it up via sal_get_socket(), which
             * asserts on the magic value. */
            socket_delete(new_sal_socket);
            LOG_E("New socket registered failed, return error %d.", retval);
            return -1;
        }

        /* the accepted connection belongs to the listening socket's protocol
         * stack, so it must use the same netdev (and thus the same socket ops) */
        new_sock->netdev = sock->netdev;

        /* socket structure user_data used to store the acquired new socket */
        new_sock->user_data = (void *) new_socket;

        return new_sal_socket;
    }

    return -1;
}

/**
 * Copy the IPv4 address part of a socket address into an ip_addr_t.
 *
 * The port is ignored. NOTE(review): `name` is always read as a
 * struct sockaddr_in; callers are assumed to pass an AF_INET address.
 *
 * @param name          socket address to convert
 * @param local_ipaddr  receives the converted address
 */
static void sal_sockaddr_to_ipaddr(const struct sockaddr *name, ip_addr_t *local_ipaddr)
{
    const struct sockaddr_in *sin = (const struct sockaddr_in *) name;

#if NETDEV_IPV4 && NETDEV_IPV6
    /* dual-stack ip_addr_t: mark the union as holding IPv4 */
    local_ipaddr->type = IPADDR_TYPE_V4;
    local_ipaddr->u_addr.ip4.addr = sin->sin_addr.s_addr;
#elif NETDEV_IPV4
    local_ipaddr->addr = sin->sin_addr.s_addr;
#elif NETDEV_IPV6
#error "not only support IPV6"
#endif /* NETDEV_IPV4 && NETDEV_IPV6 */
}

/**
 * Bind a sal socket to a local address.
 *
 * For a specific (non-any) address, the function looks up the network
 * interface device that owns it. If that device's SAL protocol family differs
 * from the socket's current one, it:
 *   - closes the old protocol-level socket,
 *   - creates a new one on the new family,
 *   - switches the sal socket to that device.
 * The socket is then bound through its device's socket operations.
 *
 * @param socket   sal socket descriptor
 * @param name     local address to bind (read as struct sockaddr_in)
 * @param namelen  length of name
 *
 * @return result of the protocol-level bind, or -1 on failure
 */
int sal_bind(int socket, const struct sockaddr *name, socklen_t namelen)
{
    struct sal_socket *sock;
    struct sal_proto_family *pf;
    ip_addr_t input_ipaddr;

    RT_ASSERT(name);

    /* get the socket object by socket descriptor */
    SAL_SOCKET_OBJ_GET(sock, socket);

    /* bind network interface by ip address */
    sal_sockaddr_to_ipaddr(name, &input_ipaddr);

    /* a specific address selects the netdev that owns it */
    if (!ip_addr_isany_val(input_ipaddr))
    {
        struct sal_proto_family *input_pf = RT_NULL, *local_pf = RT_NULL;
        struct netdev *new_netdev = RT_NULL;

        new_netdev = netdev_get_by_ipaddr(&input_ipaddr);
        if (new_netdev == RT_NULL)
        {
            return -1;
        }

        /* get input and local ip address proto_family */
        SAL_NETDEV_SOCKETOPS_VALID(sock->netdev, local_pf, bind);
        SAL_NETDEV_SOCKETOPS_VALID(new_netdev, input_pf, bind);

        /* check the network interface protocol family type */
        if (input_pf->family != local_pf->family)
        {
            int new_socket = -1;

            /* protocol family is different: close the old socket and create a
             * new one on the input ip address's family. The protocol stack only
             * knows its own descriptor (kept in user_data), not the SAL one. */
            local_pf->skt_ops->closesocket((int) sock->user_data);

            new_socket = input_pf->skt_ops->socket(input_pf->family, sock->type, sock->protocol);
            if (new_socket < 0)
            {
                return -1;
            }
            sock->netdev = new_netdev;
            sock->user_data = (void *) new_socket;
        }
    }

    /* check and get protocol families by the network interface device */
    SAL_NETDEV_SOCKETOPS_VALID(sock->netdev, pf, bind);
    return pf->skt_ops->bind((int) sock->user_data, name, namelen);
}

int sal_shutdown(int socket, int how)
{
    struct sal_socket *sock;
676 677
    struct sal_proto_family *pf;
    int error = 0;
678

679 680
    /* get the socket object by socket descriptor */
    SAL_SOCKET_OBJ_GET(sock, socket);
681

L
luhuadong 已提交
682
    /* shutdown operation not need to check network interface status */
683 684
    /* check the network interface socket opreation */
    SAL_NETDEV_SOCKETOPS_VALID(sock->netdev, pf, shutdown);
685

686
    if (pf->skt_ops->shutdown((int) sock->user_data, how) == 0)
687
    {
688 689 690 691 692 693 694 695 696
#ifdef SAL_USING_TLS
        if (SAL_SOCKOPS_PROTO_TLS_VALID(sock, closesocket))
        {
            if (proto_tls->ops->closesocket(sock->user_data_tls) < 0)
            {
                return -1;
            }
        }
#endif
697 698 699 700 701
        error = 0;
    }
    else
    {
        error = -1;
702 703
    }

704 705

    return error;
706 707 708 709 710
}

/**
 * Get the address of the peer connected to a sal socket.
 *
 * @param socket   sal socket descriptor
 * @param name     receives the peer address
 * @param namelen  in/out length of name
 *
 * @return result of the protocol-level getpeername, or -1 if the socket or
 *         its getpeername operation is not available
 */
int sal_getpeername(int socket, struct sockaddr *name, socklen_t *namelen)
{
    struct sal_socket *sock;
    struct sal_proto_family *pf;
    int proto_fd;

    SAL_SOCKET_OBJ_GET(sock, socket);
    SAL_NETDEV_SOCKETOPS_VALID(sock->netdev, pf, getpeername);

    /* user_data holds the protocol stack's own descriptor */
    proto_fd = (int) sock->user_data;
    return pf->skt_ops->getpeername(proto_fd, name, namelen);
}

/**
 * Get the local address a sal socket is bound to.
 *
 * @param socket   sal socket descriptor
 * @param name     receives the local address
 * @param namelen  in/out length of name
 *
 * @return result of the protocol-level getsockname, or -1 if the socket or
 *         its getsockname operation is not available
 */
int sal_getsockname(int socket, struct sockaddr *name, socklen_t *namelen)
{
    struct sal_socket *sock;
    struct sal_proto_family *pf;
    int proto_fd;

    SAL_SOCKET_OBJ_GET(sock, socket);
    SAL_NETDEV_SOCKETOPS_VALID(sock->netdev, pf, getsockname);

    /* user_data holds the protocol stack's own descriptor */
    proto_fd = (int) sock->user_data;
    return pf->skt_ops->getsockname(proto_fd, name, namelen);
}

/**
 * Read a socket option from a sal socket.
 *
 * @param socket   sal socket descriptor
 * @param level    option level
 * @param optname  option name
 * @param optval   receives the option value
 * @param optlen   in/out length of optval
 *
 * @return result of the protocol-level getsockopt, or -1 if the socket or
 *         its getsockopt operation is not available
 */
int sal_getsockopt(int socket, int level, int optname, void *optval, socklen_t *optlen)
{
    struct sal_socket *sock;
    struct sal_proto_family *pf;
    int proto_fd;

    SAL_SOCKET_OBJ_GET(sock, socket);
    SAL_NETDEV_SOCKETOPS_VALID(sock->netdev, pf, getsockopt);

    proto_fd = (int) sock->user_data;
    return pf->skt_ops->getsockopt(proto_fd, level, optname, optval, optlen);
}

/**
 * Set a socket option on a sal socket.
 *
 * With SAL_USING_TLS, SOL_TLS options go to the registered TLS protocol. All
 * other levels go to the protocol stack's setsockopt.
 *
 * @param socket   sal socket descriptor
 * @param level    option level
 * @param optname  option name
 * @param optval   option value
 * @param optlen   length of optval
 *
 * @return option handler result; 0 for a known SOL_TLS option that no TLS
 *         handler applies to; -1 for an unknown SOL_TLS option or an
 *         invalid socket
 */
int sal_setsockopt(int socket, int level, int optname, const void *optval, socklen_t optlen)
{
    struct sal_socket *sock;
    struct sal_proto_family *pf;

    SAL_SOCKET_OBJ_GET(sock, socket);
    SAL_NETDEV_SOCKETOPS_VALID(sock->netdev, pf, setsockopt);

#ifdef SAL_USING_TLS
    if (level != SOL_TLS)
    {
        return pf->skt_ops->setsockopt((int) sock->user_data, level, optname, optval, optlen);
    }

    /* each SAL_SOCKOPT_PROTO_TLS_EXEC returns the TLS handler's result when it applies */
    switch (optname)
    {
    case TLS_CRET_LIST:
        SAL_SOCKOPT_PROTO_TLS_EXEC(sock, set_cret_list, optval, optlen);
        break;

    case TLS_CIPHERSUITE_LIST:
        SAL_SOCKOPT_PROTO_TLS_EXEC(sock, set_ciphersurite, optval, optlen);
        break;

    case TLS_PEER_VERIFY:
        SAL_SOCKOPT_PROTO_TLS_EXEC(sock, set_peer_verify, optval, optlen);
        break;

    case TLS_DTLS_ROLE:
        SAL_SOCKOPT_PROTO_TLS_EXEC(sock, set_dtls_role, optval, optlen);
        break;

    default:
        return -1;
    }

    /* NOTE(review): reached when the socket is not TLS/DTLS or the TLS protocol
     * lacks the handler; the option is ignored but success is reported */
    return 0;
#else
    return pf->skt_ops->setsockopt((int) sock->user_data, level, optname, optval, optlen);
#endif /* SAL_USING_TLS */
}

/**
 * Connect a sal socket to a remote address.
 *
 * Requires the socket's network interface device to be up. For TLS/DTLS
 * sockets, the registered TLS protocol's connect operation runs after the
 * protocol-level connect succeeds.
 *
 * @param socket   sal socket descriptor
 * @param name     remote address
 * @param namelen  length of name
 *
 * @return result of the protocol-level connect, or -1 on failure
 */
int sal_connect(int socket, const struct sockaddr *name, socklen_t namelen)
{
    struct sal_socket *sock;
    struct sal_proto_family *pf;
    int ret;

    SAL_SOCKET_OBJ_GET(sock, socket);
    SAL_NETDEV_IS_UP(sock->netdev);
    SAL_NETDEV_SOCKETOPS_VALID(sock->netdev, pf, connect);

    ret = pf->skt_ops->connect((int) sock->user_data, name, namelen);

#ifdef SAL_USING_TLS
    if (ret >= 0 && SAL_SOCKOPS_PROTO_TLS_VALID(sock, connect) &&
        proto_tls->ops->connect(sock->user_data_tls) < 0)
    {
        return -1;
    }
#endif

    return ret;
}

int sal_listen(int socket, int backlog)
{
    struct sal_socket *sock;
830
    struct sal_proto_family *pf;
831

832 833
    /* get the socket object by socket descriptor */
    SAL_SOCKET_OBJ_GET(sock, socket);
834

835 836
    /* check the network interface socket opreation */
    SAL_NETDEV_SOCKETOPS_VALID(sock->netdev, pf, listen);
837

838
    return pf->skt_ops->listen((int) sock->user_data, backlog);
839 840 841
}

/**
 * Receive data from a sal socket.
 *
 * Requires the socket's network interface device to be up. TLS/DTLS sockets
 * read through the registered TLS protocol's recv operation; `flags`, `from`
 * and `fromlen` are not used on that path.
 *
 * @param socket   sal socket descriptor
 * @param mem      receive buffer
 * @param len      size of mem
 * @param flags    receive flags (protocol-level path only)
 * @param from     source address buffer (protocol-level path only)
 * @param fromlen  in/out length of from (protocol-level path only)
 *
 * @return number of bytes received, or -1 on failure
 */
int sal_recvfrom(int socket, void *mem, size_t len, int flags,
                 struct sockaddr *from, socklen_t *fromlen)
{
    struct sal_socket *sock;
    struct sal_proto_family *pf;

    SAL_SOCKET_OBJ_GET(sock, socket);
    SAL_NETDEV_IS_UP(sock->netdev);
    SAL_NETDEV_SOCKETOPS_VALID(sock->netdev, pf, recvfrom);

#ifdef SAL_USING_TLS
    if (SAL_SOCKOPS_PROTO_TLS_VALID(sock, recv))
    {
        int received = proto_tls->ops->recv(sock->user_data_tls, mem, len);

        return (received < 0) ? -1 : received;
    }
#endif

    return pf->skt_ops->recvfrom((int) sock->user_data, mem, len, flags, from, fromlen);
}

/**
 * Send data on a sal socket.
 *
 * Requires the socket's network interface device to be up. TLS/DTLS sockets
 * write through the registered TLS protocol's send operation; `flags`, `to`
 * and `tolen` are not used on that path.
 *
 * @param socket   sal socket descriptor
 * @param dataptr  data to send
 * @param size     number of bytes in dataptr
 * @param flags    send flags (protocol-level path only)
 * @param to       destination address (protocol-level path only)
 * @param tolen    length of to (protocol-level path only)
 *
 * @return number of bytes sent, or -1 on failure
 */
int sal_sendto(int socket, const void *dataptr, size_t size, int flags,
               const struct sockaddr *to, socklen_t tolen)
{
    struct sal_socket *sock;
    struct sal_proto_family *pf;

    SAL_SOCKET_OBJ_GET(sock, socket);
    SAL_NETDEV_IS_UP(sock->netdev);
    SAL_NETDEV_SOCKETOPS_VALID(sock->netdev, pf, sendto);

#ifdef SAL_USING_TLS
    if (SAL_SOCKOPS_PROTO_TLS_VALID(sock, send))
    {
        int sent = proto_tls->ops->send(sock->user_data_tls, dataptr, size);

        return (sent < 0) ? -1 : sent;
    }
#endif

    return pf->skt_ops->sendto((int) sock->user_data, dataptr, size, flags, to, tolen);
}

int sal_socket(int domain, int type, int protocol)
{
    int retval;
    int socket, proto_socket;
    struct sal_socket *sock;
914
    struct sal_proto_family *pf;
915 916 917 918 919 920 921

    /* allocate a new socket and registered socket options */
    socket = socket_new();
    if (socket < 0)
    {
        return -1;
    }
922 923

    /* get sal socket object by socket descriptor */
924
    sock = sal_get_socket(socket);
925 926
    if (sock == RT_NULL)
    {
927
        socket_delete(socket);
928 929
        return -1;
    }
930

931
    /* Initialize sal socket object */
932 933 934 935
    retval = socket_init(domain, type, protocol, &sock);
    if (retval < 0)
    {
        LOG_E("SAL socket protocol family input failed, return error %d.", retval);
936
        socket_delete(socket);
937 938 939
        return -1;
    }

940 941
    /* valid the network interface socket opreation */
    SAL_NETDEV_SOCKETOPS_VALID(sock->netdev, pf, socket);
942

943
    proto_socket = pf->skt_ops->socket(domain, type, protocol);
944 945
    if (proto_socket >= 0)
    {
946 947 948
#ifdef SAL_USING_TLS
        if (SAL_SOCKOPS_PROTO_TLS_VALID(sock, socket))
        {
949
            sock->user_data_tls = proto_tls->ops->socket(socket);
950 951
            if (sock->user_data_tls == RT_NULL)
            {
952
                socket_delete(socket);
953 954 955 956
                return -1;
            }
        }
#endif
957 958 959
        sock->user_data = (void *) proto_socket;
        return sock->socket;
    }
960
    socket_delete(socket);
961 962 963 964 965 966
    return -1;
}

int sal_closesocket(int socket)
{
    struct sal_socket *sock;
967 968
    struct sal_proto_family *pf;
    int error = 0;
969

970 971
    /* get the socket object by socket descriptor */
    SAL_SOCKET_OBJ_GET(sock, socket);
972

L
luhuadong 已提交
973
    /* clsoesocket operation not need to vaild network interface status */
974 975
    /* valid the network interface socket opreation */
    SAL_NETDEV_SOCKETOPS_VALID(sock->netdev, pf, socket);
976

977
    if (pf->skt_ops->closesocket((int) sock->user_data) == 0)
978
    {
979 980 981 982 983 984 985 986 987
#ifdef SAL_USING_TLS
        if (SAL_SOCKOPS_PROTO_TLS_VALID(sock, closesocket))
        {
            if (proto_tls->ops->closesocket(sock->user_data_tls) < 0)
            {
                return -1;
            }
        }
#endif
988 989 990 991 992
        error = 0;
    }
    else
    {
        error = -1;
993 994
    }

995 996
    /* delete socket */
    socket_delete(socket);
997 998

    return error;
999 1000 1001 1002 1003
}

int sal_ioctlsocket(int socket, long cmd, void *arg)
{
    struct sal_socket *sock;
1004
    struct sal_proto_family *pf;
1005

1006 1007
    /* get the socket object by socket descriptor */
    SAL_SOCKET_OBJ_GET(sock, socket);
1008

1009 1010
    /* check the network interface socket opreation */
    SAL_NETDEV_SOCKETOPS_VALID(sock->netdev, pf, ioctlsocket);
1011

1012
    return pf->skt_ops->ioctlsocket((int) sock->user_data, cmd, arg);
1013 1014
}

1015
#ifdef SAL_USING_POSIX
/**
 * Poll handler for a SAL socket reached through a DFS file object.
 *
 * The SAL descriptor is carried in file->data. After validation, the
 * original file object and poll request are handed to the protocol
 * family's poll operation.
 *
 * @param file DFS file object whose data field holds the SAL descriptor
 * @param req  poll request, passed through unchanged
 *
 * @return the protocol family's poll result. The validation macros return
 *         early when the descriptor is unknown, the interface is not up, or
 *         the family lacks a poll operation.
 */
int sal_poll(struct dfs_fd *file, struct rt_pollreq *req)
{
    int socket = (int) file->data;
    struct sal_socket *sock;
    struct sal_proto_family *pf;

    /* resolve the SAL descriptor stored in the file object */
    SAL_SOCKET_OBJ_GET(sock, socket);

    /* polling requires the owning interface to be up ... */
    SAL_NETDEV_IS_UP(sock->netdev);
    /* ... and a protocol family that implements poll */
    SAL_NETDEV_SOCKETOPS_VALID(sock->netdev, pf, poll);

    return pf->skt_ops->poll(file, req);
}
#endif /* SAL_USING_POSIX */
1033 1034 1035

struct hostent *sal_gethostbyname(const char *name)
{
1036 1037
    struct netdev *netdev = netdev_default;
    struct sal_proto_family *pf;
1038

1039
    if (SAL_NETDEV_NETDBOPS_VALID(netdev, pf, gethostbyname))
1040
    {
1041 1042 1043 1044
        return pf->netdb_ops->gethostbyname(name);
    }
    else
    {
1045 1046
        /* get the first network interface device with up status */
        netdev = netdev_get_first_by_flags(NETDEV_FLAG_UP);
1047
        if (SAL_NETDEV_NETDBOPS_VALID(netdev, pf, gethostbyname))
1048
        {
1049
            return pf->netdb_ops->gethostbyname(name);
1050 1051 1052 1053 1054 1055 1056
        }
    }

    return RT_NULL;
}

int sal_gethostbyname_r(const char *name, struct hostent *ret, char *buf,
1057
                        size_t buflen, struct hostent **result, int *h_errnop)
1058
{
1059 1060
    struct netdev *netdev = netdev_default;
    struct sal_proto_family *pf;
1061

1062
    if (SAL_NETDEV_NETDBOPS_VALID(netdev, pf, gethostbyname_r))
1063
    {
1064 1065 1066 1067
        return pf->netdb_ops->gethostbyname_r(name, ret, buf, buflen, result, h_errnop);
    }
    else
    {
1068 1069
        /* get the first network interface device with up status */
        netdev = netdev_get_first_by_flags(NETDEV_FLAG_UP);
1070
        if (SAL_NETDEV_NETDBOPS_VALID(netdev, pf, gethostbyname_r))
1071
        {
1072
            return pf->netdb_ops->gethostbyname_r(name, ret, buf, buflen, result, h_errnop);
1073 1074 1075 1076 1077 1078 1079
        }
    }

    return -1;
}

/**
 * Address lookup through the SAL network-database operations.
 *
 * The default network interface is tried first. If it fails the
 * SAL_NETDEV_NETDBOPS_VALID check, the first interface carrying
 * NETDEV_FLAG_UP is tried instead. All arguments are passed through
 * unchanged to the protocol family's getaddrinfo.
 *
 * @param nodename host name or numeric address to look up
 * @param servname service name or port string
 * @param hints    lookup hints, may be forwarded as-is
 * @param res      receives the resulting address list
 *
 * @return the protocol family's result. Returns -1, with *res set to RT_NULL
 *         (when res is non-NULL), if no interface can service the request.
 */
int sal_getaddrinfo(const char *nodename,
                    const char *servname,
                    const struct addrinfo *hints,
                    struct addrinfo **res)
{
    struct netdev *netdev = netdev_default;
    struct sal_proto_family *pf;

    if (SAL_NETDEV_NETDBOPS_VALID(netdev, pf, getaddrinfo))
    {
        return pf->netdb_ops->getaddrinfo(nodename, servname, hints, res);
    }

    /* get the first network interface device with up status */
    netdev = netdev_get_first_by_flags(NETDEV_FLAG_UP);
    if (SAL_NETDEV_NETDBOPS_VALID(netdev, pf, getaddrinfo))
    {
        return pf->netdb_ops->getaddrinfo(nodename, servname, hints, res);
    }

    /* no interface could handle the lookup: leave *res in a defined state
     * instead of an indeterminate pointer a caller might later free */
    if (res != RT_NULL)
    {
        *res = RT_NULL;
    }

    return -1;
}
1103 1104 1105

void sal_freeaddrinfo(struct addrinfo *ai)
{
1106 1107
    struct netdev *netdev = netdev_default;
    struct sal_proto_family *pf;
1108

1109
    if (SAL_NETDEV_NETDBOPS_VALID(netdev, pf, freeaddrinfo))
1110
    {
1111 1112 1113 1114
        pf->netdb_ops->freeaddrinfo(ai);
    }
    else
    {
1115 1116
        /* get the first network interface device with up status */
        netdev = netdev_get_first_by_flags(NETDEV_FLAG_UP);
1117
        if (SAL_NETDEV_NETDBOPS_VALID(netdev, pf, freeaddrinfo))
1118
        {
1119
            pf->netdb_ops->freeaddrinfo(ai);
1120 1121 1122
        }
    }
}
1123