/*
 * Copyright (c) 2006-2022, RT-Thread Development Team
 *
 * SPDX-License-Identifier: Apache-2.0
 *
 * Change Logs:
 * Date           Author       Notes
 * 2018-05-23     ChenYong     First version
 * 2018-11-12     ChenYong     Add TLS support
 */

#include <rtthread.h>
#include <rthw.h>
#include <sys/time.h>
#include <sys/ioctl.h>

#include <sal_socket.h>
#include <sal_netdb.h>
#ifdef SAL_USING_TLS
#include <sal_tls.h>
#endif
#include <sal_low_lvl.h>
#include <netdev.h>

#ifdef SAL_INTERNET_CHECK
#include <ipc/workqueue.h>
#endif

/* check system workqueue stack size */
#if RT_SYSTEM_WORKQUEUE_STACKSIZE < 1536
#error "The system workqueue stack size must more than 1536 bytes"
#endif

#define DBG_TAG                        "sal.skt"
#define DBG_LVL                        DBG_INFO
36 37
#include <rtdbg.h>

38 39
#define SOCKET_TABLE_STEP_LEN          4

/*
 * Growable table of SAL socket objects. Slot i holds the object for the SAL
 * descriptor i + SAL_SOCKET_OFFSET, or RT_NULL while the slot is free.
 */
struct sal_socket_table
{
    uint32_t max_socket;            /* number of slots currently allocated in 'sockets' */
    struct sal_socket **sockets;    /* array of 'max_socket' socket object pointers */
};

/* Records an address-info result list together with a network interface device. */
struct sal_netdev_res_table
{
    struct addrinfo *res;       /* address-info result list */
    struct netdev *netdev;      /* network interface device associated with 'res' */
};

54 55 56 57 58
#ifdef SAL_USING_TLS
/* The global TLS protocol options */
static struct sal_proto_tls *proto_tls;
#endif

59 60 61 62
/* The global socket table */
static struct sal_socket_table socket_table;
static struct rt_mutex sal_core_lock;
static rt_bool_t init_ok = RT_FALSE;
63
static struct sal_netdev_res_table sal_dev_res_tbl[SAL_SOCKETS_NUM];
64

/* True when the socket was created with the TLS or DTLS protocol number. */
#define IS_SOCKET_PROTO_TLS(sock)                (((sock)->protocol == PROTOCOL_TLS) || \
                                                 ((sock)->protocol == PROTOCOL_DTLS))
/* True when a TLS backend is registered, implements operation 'name', and 'sock' is TLS/DTLS. */
#define SAL_SOCKOPS_PROTO_TLS_VALID(sock, name)  (proto_tls && (proto_tls->ops->name) && IS_SOCKET_PROTO_TLS(sock))

/* Forward a TLS option to the TLS backend and return its result from the calling
 * function; does nothing when the backend cannot handle the option. */
#define SAL_SOCKOPT_PROTO_TLS_EXEC(sock, name, optval, optlen)                    \
do {                                                                              \
    if (SAL_SOCKOPS_PROTO_TLS_VALID(sock, name)) {                                \
        return proto_tls->ops->name((sock)->user_data_tls, (optval), (optlen));   \
    }                                                                             \
} while (0)

/* Look up the SAL socket object of descriptor 'socket' into 'sock';
 * return -1 from the calling function when the descriptor is invalid. */
#define SAL_SOCKET_OBJ_GET(sock, socket)                                          \
do {                                                                              \
    (sock) = sal_get_socket(socket);                                              \
    if ((sock) == RT_NULL) {                                                      \
        return -1;                                                                \
    }                                                                             \
} while (0)

/* Return -1 from the calling function unless 'netdev' exists and is up. */
#define SAL_NETDEV_IS_UP(netdev)                                                  \
do {                                                                              \
    if ((netdev) == RT_NULL || !netdev_is_up(netdev)) {                           \
        return -1;                                                                \
    }                                                                             \
} while (0)

/* Resolve the protocol family bound to 'netdev' into 'pf'; return -1 from the
 * calling function unless that family provides socket operation 'ops'. */
#define SAL_NETDEV_SOCKETOPS_VALID(netdev, pf, ops)                               \
do {                                                                              \
    if ((netdev) == RT_NULL) {                                                    \
        return -1;                                                                \
    }                                                                             \
    (pf) = (struct sal_proto_family *) (netdev)->sal_user_data;                   \
    if ((pf) == RT_NULL || (pf)->skt_ops == RT_NULL ||                            \
        (pf)->skt_ops->ops == RT_NULL) {                                          \
        return -1;                                                                \
    }                                                                             \
} while (0)

/* Expression: true when 'netdev' is up and its protocol family (stored into
 * 'pf') provides netdb operation 'ops'. */
#define SAL_NETDEV_NETDBOPS_VALID(netdev, pf, ops)                                \
    ((netdev) && netdev_is_up(netdev) &&                                          \
    ((pf) = (struct sal_proto_family *) (netdev)->sal_user_data) != RT_NULL &&    \
    (pf)->netdb_ops != RT_NULL && (pf)->netdb_ops->ops)

/* Same as SAL_NETDEV_NETDBOPS_VALID, but does not require 'netdev' to be up. */
#define SAL_NETDBOPS_VALID(netdev, pf, ops)                                       \
    ((netdev) &&                                                                  \
    ((pf) = (struct sal_proto_family *) (netdev)->sal_user_data) != RT_NULL &&    \
    (pf)->netdb_ops != RT_NULL && (pf)->netdb_ops->ops)

109
/**
110
 * SAL (Socket Abstraction Layer) initialize.
111
 *
112
 * @return result  0: initialize success
113
 *                -1: initialize failed
114 115 116
 */
int sal_init(void)
{
117
    int cn;
118

119
    if (init_ok)
120 121 122 123 124
    {
        LOG_D("Socket Abstraction Layer is already initialized.");
        return 0;
    }

125 126 127 128 129 130 131 132 133
    /* init sal socket table */
    cn = SOCKET_TABLE_STEP_LEN < SAL_SOCKETS_NUM ? SOCKET_TABLE_STEP_LEN : SAL_SOCKETS_NUM;
    socket_table.max_socket = cn;
    socket_table.sockets = rt_calloc(1, cn * sizeof(struct sal_socket *));
    if (socket_table.sockets == RT_NULL)
    {
        LOG_E("No memory for socket table.\n");
        return -1;
    }
134

135 136
    /*init the dev_res table */
    rt_memset(sal_dev_res_tbl,  0, sizeof(sal_dev_res_tbl));
137

138
    /* create sal socket lock */
139
    rt_mutex_init(&sal_core_lock, "sal_lock", RT_IPC_FLAG_PRIO);
140 141 142 143 144 145 146 147

    LOG_I("Socket Abstraction Layer initialize success.");
    init_ok = RT_TRUE;

    return 0;
}
INIT_COMPONENT_EXPORT(sal_init);

#ifdef SAL_INTERNET_CHECK
/**
 * Workqueue handler that probes whether a network interface device can reach
 * the internet and records the outcome with netdev_low_level_set_internet_status().
 *
 * It resolves SAL_INTERNET_HOST through the device's netdb operations and sends
 * a SAL_INTERNET_BUFF_LEN-byte UDP request to SAL_INTERNET_PORT. It then waits
 * up to SAL_INTERNET_TIMEOUT seconds for a one-byte reply. A non-zero reply
 * marks the device internet-up; any failure or a zero reply marks it down.
 *
 * Request layout:
 *   [0]      SAL_INTERNET_VERSION
 *   [1..8]   hardware address bytes, each offset by the build month (1..12)
 *   [9..11]  RT_VERSION_MAJOR / RT_VERSION_MINOR / RT_VERSION_PATCH
 *
 * @param work      heap-allocated work item for this run; freed here
 * @param work_data the struct netdev to probe
 */
static void check_netdev_internet_up_work(struct rt_work *work, void *work_data)
{
#define SAL_INTERNET_VERSION   0x00
#define SAL_INTERNET_BUFF_LEN  12
#define SAL_INTERNET_TIMEOUT   (2)

#define SAL_INTERNET_HOST      "link.rt-thread.org"
#define SAL_INTERNET_PORT      8101

#define SAL_INTERNET_MONTH_LEN 4
#define SAL_INTERNET_DATE_LEN  16

/* request bytes [1 .. SAL_INTERNET_HWADDR_MAX] carry the hardware address */
#define SAL_INTERNET_HWADDR_MAX 8

    unsigned int index;
    unsigned int hwaddr_len;
    int sockfd = -1, result = 0;
    struct sockaddr_in server_addr;
    struct hostent *host;
    struct timeval timeout;
    struct netdev *netdev = (struct netdev *)work_data;
    socklen_t addr_len = sizeof(struct sockaddr_in);
    char send_data[SAL_INTERNET_BUFF_LEN], recv_data = 0;

    const char month[][SAL_INTERNET_MONTH_LEN] = {"Jan", "Feb", "Mar", "Apr", "May", "Jun", "Jul", "Aug", "Sep", "Oct", "Nov", "Dec"};
    char date[SAL_INTERNET_DATE_LEN];
    unsigned int moth_num = 0;

    struct sal_proto_family *pf = (struct sal_proto_family *) netdev->sal_user_data;
    const struct sal_socket_ops *skt_ops = RT_NULL;

    /* the work item was allocated for this single run */
    if (work)
    {
        rt_free(work);
    }

    /* the probe needs both the device's socket operations and its DNS lookup */
    if (pf == RT_NULL || pf->skt_ops == RT_NULL ||
        pf->netdb_ops == RT_NULL || pf->netdb_ops->gethostbyname == RT_NULL)
    {
        result = -RT_ERROR;
        goto __exit;
    }

    host = (struct hostent *) pf->netdb_ops->gethostbyname(SAL_INTERNET_HOST);
    if (host == RT_NULL)
    {
        result = -RT_ERROR;
        goto __exit;
    }

    skt_ops = pf->skt_ops;
    if ((sockfd = skt_ops->socket(AF_INET, SOCK_DGRAM, 0)) < 0)
    {
        result = -RT_ERROR;
        goto __exit;
    }

    server_addr.sin_family = AF_INET;
    server_addr.sin_port = htons(SAL_INTERNET_PORT);
    server_addr.sin_addr = *((struct in_addr *)host->h_addr);
    rt_memset(&(server_addr.sin_zero), 0, sizeof(server_addr.sin_zero));

    timeout.tv_sec = SAL_INTERNET_TIMEOUT;
    timeout.tv_usec = 0;

    /* set receive and send timeout */
    skt_ops->setsockopt(sockfd, SOL_SOCKET, SO_RCVTIMEO, (void *) &timeout, sizeof(timeout));
    skt_ops->setsockopt(sockfd, SOL_SOCKET, SO_SNDTIMEO, (void *) &timeout, sizeof(timeout));

    /* get the build month (1..12) from __DATE__, whose format is "Mmm dd yyyy" */
    rt_memset(date, 0x00, SAL_INTERNET_DATE_LEN);
    rt_snprintf(date, SAL_INTERNET_DATE_LEN, "%s", __DATE__);

    for (index = 0; index < sizeof(month) / SAL_INTERNET_MONTH_LEN; index++)
    {
        if (rt_memcmp(date, month[index], SAL_INTERNET_MONTH_LEN - 1) == 0)
        {
            moth_num = index + 1;
            break;
        }
    }

    /* not find build month */
    if (moth_num == 0 || moth_num > sizeof(month) / SAL_INTERNET_MONTH_LEN)
    {
        result = -RT_ERROR;
        goto __exit;
    }

    rt_memset(send_data, 0x00, SAL_INTERNET_BUFF_LEN);
    send_data[0] = SAL_INTERNET_VERSION;

    /* bound the copy so a long hardware address cannot overrun send_data[1..8] */
    hwaddr_len = netdev->hwaddr_len;
    if (hwaddr_len > SAL_INTERNET_HWADDR_MAX)
    {
        hwaddr_len = SAL_INTERNET_HWADDR_MAX;
    }
    for (index = 0; index < hwaddr_len; index++)
    {
        send_data[index + 1] = netdev->hwaddr[index] + moth_num;
    }
    send_data[9] = RT_VERSION_MAJOR;
    send_data[10] = RT_VERSION_MINOR;
    send_data[11] = RT_VERSION_PATCH;

    skt_ops->sendto(sockfd, send_data, SAL_INTERNET_BUFF_LEN, 0,
                    (struct sockaddr *)&server_addr, sizeof(struct sockaddr));

    result = skt_ops->recvfrom(sockfd, &recv_data, sizeof(recv_data), 0, (struct sockaddr *)&server_addr, &addr_len);
    if (result < 0)
    {
        goto __exit;
    }

    if (recv_data == RT_FALSE)
    {
        result = -RT_ERROR;
        goto __exit;
    }

__exit:
    if (result > 0)
    {
        LOG_D("Set network interface device(%s) internet status up.", netdev->name);
        netdev_low_level_set_internet_status(netdev, RT_TRUE);
    }
    else
    {
        LOG_D("Set network interface device(%s) internet status down.", netdev->name);
        netdev_low_level_set_internet_status(netdev, RT_FALSE);
    }

    if (sockfd >= 0)
    {
        skt_ops->closesocket(sockfd);
    }
}
#endif /* SAL_INTERNET_CHECK */

/**
 * Schedule an asynchronous internet-reachability check for a network interface device.
 *
 * With SAL_INTERNET_CHECK enabled, a work item is allocated and submitted with a
 * delay of RT_TICK_PER_SECOND ticks. The work handler frees the item and
 * reports the result through netdev_low_level_set_internet_status(). Without
 * SAL_INTERNET_CHECK, only the argument is asserted.
 *
 * @param netdev the network interface device to check (must not be NULL)
 *
 * @return  0: check scheduled (or checking disabled)
 *         -1: no memory for the work item
 */
int sal_check_netdev_internet_up(struct netdev *netdev)
{
#ifdef SAL_INTERNET_CHECK
    struct rt_work *probe_work;
#endif

    RT_ASSERT(netdev);

#ifdef SAL_INTERNET_CHECK
    probe_work = (struct rt_work *)rt_calloc(1, sizeof(struct rt_work));
    if (probe_work == RT_NULL)
    {
        LOG_W("No memory for network interface device(%s) delay work.", netdev->name);
        return -1;
    }

    /* ownership of probe_work passes to check_netdev_internet_up_work() */
    rt_work_init(probe_work, check_netdev_internet_up_work, (void *)netdev);
    rt_work_submit(probe_work, RT_TICK_PER_SECOND);
#endif /* SAL_INTERNET_CHECK */

    return 0;
}

#ifdef SAL_USING_TLS
/**
 * Register the TLS protocol implementation used by TLS/DTLS SAL sockets.
 * A later call replaces the previously registered implementation.
 *
 * @param pt TLS protocol object (must not be NULL)
 *
 * @return 0: TLS protocol object registered
 */
int sal_proto_tls_register(const struct sal_proto_tls *pt)
{
    RT_ASSERT(pt);

    /* the file-scope pointer is non-const, hence the cast */
    proto_tls = (struct sal_proto_tls *) pt;
    return 0;
}
#endif /* SAL_USING_TLS */

/**
 * This function will get sal socket object by sal socket descriptor.
326 327 328
 *
 * @param socket sal socket index
 *
329
 * @return sal socket object of the current sal socket index
330 331 332 333 334
 */
struct sal_socket *sal_get_socket(int socket)
{
    struct sal_socket_table *st = &socket_table;

C
caixf 已提交
335 336
    socket = socket - SAL_SOCKET_OFFSET;

337 338 339 340 341 342
    if (socket < 0 || socket >= (int) st->max_socket)
    {
        return RT_NULL;
    }

    /* check socket structure valid or not */
343
    RT_ASSERT(st->sockets[socket]->magic == SAL_SOCKET_MAGIC);
344 345 346 347 348

    return st->sockets[socket];
}

/**
 * Acquire the SAL core lock, waiting forever. Failing to take the mutex is
 * treated as fatal (assertion).
 *
 * @note must not be called from interrupt context.
 */
static void sal_lock(void)
{
    if (rt_mutex_take(&sal_core_lock, RT_WAITING_FOREVER) != RT_EOK)
    {
        RT_ASSERT(0);
    }
}

/**
 * Release the SAL core lock previously taken with sal_lock().
 *
 * @note must not be called from interrupt context.
 */
static void sal_unlock(void)
{
    (void) rt_mutex_release(&sal_core_lock);
}

374 375 376 377 378 379 380
/**
 * This function will clean the netdev.
 *
 * @note please don't invoke it on ISR.
 */
int sal_netdev_cleanup(struct netdev *netdev)
{
还_没_想_好's avatar
还_没_想_好 已提交
381 382
    uint32_t idx = 0;
    int find_dev;
383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398

    do
    {
        find_dev = 0;
        sal_lock();
        for (idx = 0; idx < socket_table.max_socket; idx++)
        {
            if (socket_table.sockets[idx] && socket_table.sockets[idx]->netdev == netdev)
            {
                find_dev = 1;
                break;
            }
        }
        sal_unlock();
        if (find_dev)
        {
399
            rt_thread_mdelay(100);
400
        }
401 402
    }
    while (find_dev);
403 404 405 406

    return 0;
}

407
/**
408
 * This function will initialize sal socket object and set socket options
409 410 411 412
 *
 * @param family    protocol family
 * @param type      socket type
 * @param protocol  transfer Protocol
413
 * @param res       sal socket object address
414 415 416 417
 *
 * @return  0 : socket initialize success
 *         -1 : input the wrong family
 *         -2 : input the wrong socket type
418
 *         -3 : get network interface failed
419 420 421
 */
static int socket_init(int family, int type, int protocol, struct sal_socket **res)
{
422

423
    struct sal_socket *sock;
424
    struct sal_proto_family *pf;
425 426
    struct netdev *netdv_def = netdev_default;
    struct netdev *netdev = RT_NULL;
427
    rt_bool_t flag = RT_FALSE;
428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443

    if (family < 0 || family > AF_MAX)
    {
        return -1;
    }

    if (type < 0 || type > SOCK_MAX)
    {
        return -2;
    }

    sock = *res;
    sock->domain = family;
    sock->type = type;
    sock->protocol = protocol;

444
    if (netdv_def && netdev_is_up(netdv_def))
445
    {
446 447 448
        /* check default network interface device protocol family */
        pf = (struct sal_proto_family *) netdv_def->sal_user_data;
        if (pf != RT_NULL && pf->skt_ops && (pf->family == family || pf->sec_family == family))
449
        {
450
            sock->netdev = netdv_def;
451
            flag = RT_TRUE;
452 453
        }
    }
454

455
    if (flag == RT_FALSE)
456
    {
457 458 459 460 461 462 463 464 465
        /* get network interface device by protocol family */
        netdev = netdev_get_by_family(family);
        if (netdev == RT_NULL)
        {
            LOG_E("not find network interface device by protocol family(%d).", family);
            return -3;
        }

        sock->netdev = netdev;
466 467 468 469 470 471 472 473 474 475 476 477
    }

    return 0;
}

static int socket_alloc(struct sal_socket_table *st, int f_socket)
{
    int idx;

    /* find an empty socket entry */
    for (idx = f_socket; idx < (int) st->max_socket; idx++)
    {
478
        if (st->sockets[idx] == RT_NULL)
479
        {
480
            break;
481
        }
482 483 484 485 486 487 488 489
    }

    /* allocate a larger sockte container */
    if (idx == (int) st->max_socket &&  st->max_socket < SAL_SOCKETS_NUM)
    {
        int cnt, index;
        struct sal_socket **sockets;

490 491
        /* increase the number of socket with 4 step length */
        cnt = st->max_socket + SOCKET_TABLE_STEP_LEN;
492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510
        cnt = cnt > SAL_SOCKETS_NUM ? SAL_SOCKETS_NUM : cnt;

        sockets = rt_realloc(st->sockets, cnt * sizeof(struct sal_socket *));
        if (sockets == RT_NULL)
            goto __result; /* return st->max_socket */

        /* clean the new allocated fds */
        for (index = st->max_socket; index < cnt; index++)
        {
            sockets[index] = RT_NULL;
        }

        st->sockets = sockets;
        st->max_socket = cnt;
    }

    /* allocate  'struct sal_socket' */
    if (idx < (int) st->max_socket && st->sockets[idx] == RT_NULL)
    {
511
        st->sockets[idx] = rt_calloc(1, sizeof(struct sal_socket));
512 513 514 515 516 517 518 519 520 521
        if (st->sockets[idx] == RT_NULL)
        {
            idx = st->max_socket;
        }
    }

__result:
    return idx;
}

522 523 524 525 526 527 528 529 530
static void socket_free(struct sal_socket_table *st, int idx)
{
    struct sal_socket *sock;

    sock = st->sockets[idx];
    st->sockets[idx] = RT_NULL;
    rt_free(sock);
}

531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551
static int socket_new(void)
{
    struct sal_socket *sock;
    struct sal_socket_table *st = &socket_table;
    int idx;

    sal_lock();

    /* find an empty sal socket entry */
    idx = socket_alloc(st, 0);

    /* can't find an empty sal socket entry */
    if (idx == (int) st->max_socket)
    {
        idx = -(1 + SAL_SOCKET_OFFSET);
        goto __result;
    }

    sock = st->sockets[idx];
    sock->socket = idx + SAL_SOCKET_OFFSET;
    sock->magic = SAL_SOCKET_MAGIC;
552
    sock->netdev = RT_NULL;
553 554 555 556
    sock->user_data = RT_NULL;
#ifdef SAL_USING_TLS
    sock->user_data_tls = RT_NULL;
#endif
557 558 559 560 561 562

__result:
    sal_unlock();
    return idx + SAL_SOCKET_OFFSET;
}

563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582
static void socket_delete(int socket)
{
    struct sal_socket *sock;
    struct sal_socket_table *st = &socket_table;
    int idx;

    idx = socket - SAL_SOCKET_OFFSET;
    if (idx < 0 || idx >= (int) st->max_socket)
    {
        return;
    }
    sal_lock();
    sock = sal_get_socket(socket);
    RT_ASSERT(sock != RT_NULL);
    sock->magic = 0;
    sock->netdev = RT_NULL;
    socket_free(st, idx);
    sal_unlock();
}

583 584 585 586
int sal_accept(int socket, struct sockaddr *addr, socklen_t *addrlen)
{
    int new_socket;
    struct sal_socket *sock;
587
    struct sal_proto_family *pf;
588

589 590
    /* get the socket object by socket descriptor */
    SAL_SOCKET_OBJ_GET(sock, socket);
591

592
    /* check the network interface is up status */
593 594
    SAL_NETDEV_IS_UP(sock->netdev);

595 596
    /* check the network interface socket operations */
    SAL_NETDEV_SOCKETOPS_VALID(sock->netdev, pf, accept);
597

G
guo 已提交
598
    new_socket = pf->skt_ops->accept((int)(size_t)sock->user_data, addr, addrlen);
599 600 601
    if (new_socket != -1)
    {
        int retval;
602
        int new_sal_socket;
603 604 605
        struct sal_socket *new_sock;

        /* allocate a new socket structure and registered socket options */
606
        new_sal_socket = socket_new();
607 608
        new_sock = sal_get_socket(new_sal_socket);
        if (new_sock == RT_NULL)
609
        {
610
            pf->skt_ops->closesocket(new_socket);
611 612 613 614 615 616
            return -1;
        }

        retval = socket_init(sock->domain, sock->type, sock->protocol, &new_sock);
        if (retval < 0)
        {
617
            pf->skt_ops->closesocket(new_socket);
618
            rt_memset(new_sock, 0x00, sizeof(struct sal_socket));
619 620
            /* socket init failed, delete socket */
            socket_delete(new_sal_socket);
621 622 623
            LOG_E("New socket registered failed, return error %d.", retval);
            return -1;
        }
624

N
NightIsDark 已提交
625 626
        /* new socket create by accept should have the same netdev with server*/
        new_sock->netdev = sock->netdev;
627
        /* socket structure user_data used to store the acquired new socket */
G
guo 已提交
628
        new_sock->user_data = (void *)(size_t)new_socket;
629

630
        return new_sal_socket;
631 632 633 634 635
    }

    return -1;
}

636 637 638 639 640
static void sal_sockaddr_to_ipaddr(const struct sockaddr *name, ip_addr_t *local_ipaddr)
{
    const struct sockaddr_in *svr_addr = (const struct sockaddr_in *) name;

#if NETDEV_IPV4 && NETDEV_IPV6
641 642
    local_ipaddr->u_addr.ip4.addr = svr_addr->sin_addr.s_addr;
    local_ipaddr->type = IPADDR_TYPE_V4;
643
#elif NETDEV_IPV4
644
    local_ipaddr->addr = svr_addr->sin_addr.s_addr;
645
#elif NETDEV_IPV6
646 647
#error "not only support IPV6"
#endif /* NETDEV_IPV4 && NETDEV_IPV6*/
648
}
649

650 651 652
int sal_bind(int socket, const struct sockaddr *name, socklen_t namelen)
{
    struct sal_socket *sock;
653
    struct sal_proto_family *pf;
654
    ip_addr_t input_ipaddr;
655

656 657 658 659 660 661
    RT_ASSERT(name);

    /* get the socket object by socket descriptor */
    SAL_SOCKET_OBJ_GET(sock, socket);

    /* bind network interface by ip address */
662
    sal_sockaddr_to_ipaddr(name, &input_ipaddr);
663 664

    /* check input ipaddr is default netdev ipaddr */
665
    if (!ip_addr_isany_val(input_ipaddr))
666
    {
667 668
        struct sal_proto_family *input_pf = RT_NULL, *local_pf = RT_NULL;
        struct netdev *new_netdev = RT_NULL;
669

670 671 672 673 674
        new_netdev = netdev_get_by_ipaddr(&input_ipaddr);
        if (new_netdev == RT_NULL)
        {
            return -1;
        }
675

676 677 678
        /* get input and local ip address proto_family */
        SAL_NETDEV_SOCKETOPS_VALID(sock->netdev, local_pf, bind);
        SAL_NETDEV_SOCKETOPS_VALID(new_netdev, input_pf, bind);
679

680 681
        /* check the network interface protocol family type */
        if (input_pf->family != local_pf->family)
682
        {
683 684 685 686 687 688 689 690 691 692 693
            int new_socket = -1;

            /* protocol family is different, close old socket and create new socket by input ip address */
            local_pf->skt_ops->closesocket(socket);

            new_socket = input_pf->skt_ops->socket(input_pf->family, sock->type, sock->protocol);
            if (new_socket < 0)
            {
                return -1;
            }
            sock->netdev = new_netdev;
G
guo 已提交
694
            sock->user_data = (void *)(size_t)new_socket;
695
        }
696
    }
697

698
    /* check and get protocol families by the network interface device */
699
    SAL_NETDEV_SOCKETOPS_VALID(sock->netdev, pf, bind);
G
guo 已提交
700
    return pf->skt_ops->bind((int)(size_t)sock->user_data, name, namelen);
701 702 703 704 705
}

int sal_shutdown(int socket, int how)
{
    struct sal_socket *sock;
706 707
    struct sal_proto_family *pf;
    int error = 0;
708

709 710
    /* get the socket object by socket descriptor */
    SAL_SOCKET_OBJ_GET(sock, socket);
711

L
luhuadong 已提交
712
    /* shutdown operation not need to check network interface status */
713 714
    /* check the network interface socket opreation */
    SAL_NETDEV_SOCKETOPS_VALID(sock->netdev, pf, shutdown);
715

G
guo 已提交
716
    if (pf->skt_ops->shutdown((int)(size_t)sock->user_data, how) == 0)
717
    {
718 719 720 721 722 723 724 725 726
#ifdef SAL_USING_TLS
        if (SAL_SOCKOPS_PROTO_TLS_VALID(sock, closesocket))
        {
            if (proto_tls->ops->closesocket(sock->user_data_tls) < 0)
            {
                return -1;
            }
        }
#endif
727 728 729 730 731
        error = 0;
    }
    else
    {
        error = -1;
732 733
    }

734 735

    return error;
736 737 738 739 740
}

/**
 * Get the address of the peer connected to a SAL socket.
 *
 * @return the protocol-level getpeername() result, or -1 when the descriptor
 *         or the device's socket operation is invalid
 */
int sal_getpeername(int socket, struct sockaddr *name, socklen_t *namelen)
{
    struct sal_proto_family *pf;
    struct sal_socket *sock;
    int proto_fd;

    SAL_SOCKET_OBJ_GET(sock, socket);
    SAL_NETDEV_SOCKETOPS_VALID(sock->netdev, pf, getpeername);

    proto_fd = (int)(size_t)sock->user_data;
    return pf->skt_ops->getpeername(proto_fd, name, namelen);
}

/**
 * Get the local address a SAL socket is bound to.
 *
 * @return the protocol-level getsockname() result, or -1 when the descriptor
 *         or the device's socket operation is invalid
 */
int sal_getsockname(int socket, struct sockaddr *name, socklen_t *namelen)
{
    struct sal_proto_family *pf;
    struct sal_socket *sock;
    int proto_fd;

    SAL_SOCKET_OBJ_GET(sock, socket);
    SAL_NETDEV_SOCKETOPS_VALID(sock->netdev, pf, getsockname);

    proto_fd = (int)(size_t)sock->user_data;
    return pf->skt_ops->getsockname(proto_fd, name, namelen);
}

/**
 * Read a socket option from the protocol-level socket behind a SAL socket.
 *
 * @return the protocol-level getsockopt() result, or -1 when the descriptor
 *         or the device's socket operation is invalid
 */
int sal_getsockopt(int socket, int level, int optname, void *optval, socklen_t *optlen)
{
    struct sal_proto_family *pf;
    struct sal_socket *sock;
    int proto_fd;

    SAL_SOCKET_OBJ_GET(sock, socket);
    SAL_NETDEV_SOCKETOPS_VALID(sock->netdev, pf, getsockopt);

    proto_fd = (int)(size_t)sock->user_data;
    return pf->skt_ops->getsockopt(proto_fd, level, optname, optval, optlen);
}

/**
 * Set a socket option on a SAL socket.
 *
 * With SAL_USING_TLS, options at level SOL_TLS are forwarded to the registered
 * TLS backend. All other levels go to the protocol-level setsockopt().
 *
 * @return the handler's result; 0 for a known SOL_TLS option no backend
 *         handles; -1 for an unknown SOL_TLS option or an invalid descriptor
 */
int sal_setsockopt(int socket, int level, int optname, const void *optval, socklen_t optlen)
{
    struct sal_socket *sock;
    struct sal_proto_family *pf;

    /* get the socket object by socket descriptor */
    SAL_SOCKET_OBJ_GET(sock, socket);

    /* check the network interface socket opreation */
    SAL_NETDEV_SOCKETOPS_VALID(sock->netdev, pf, setsockopt);

#ifdef SAL_USING_TLS
    if (level == SOL_TLS)
    {
        switch (optname)
        {
        case TLS_CRET_LIST:
            SAL_SOCKOPT_PROTO_TLS_EXEC(sock, set_cret_list, optval, optlen);
            break;

        case TLS_CIPHERSUITE_LIST:
            SAL_SOCKOPT_PROTO_TLS_EXEC(sock, set_ciphersurite, optval, optlen);
            break;

        case TLS_PEER_VERIFY:
            SAL_SOCKOPT_PROTO_TLS_EXEC(sock, set_peer_verify, optval, optlen);
            break;

        case TLS_DTLS_ROLE:
            SAL_SOCKOPT_PROTO_TLS_EXEC(sock, set_dtls_role, optval, optlen);
            break;

        default:
            return -1;
        }

        /* NOTE(review): reached when no TLS backend handles the option; still reports success */
        return 0;
    }
#endif /* SAL_USING_TLS */

    /* go through size_t so the pointer-to-int round trip is warning-free on 64-bit */
    return pf->skt_ops->setsockopt((int)(size_t)sock->user_data, level, optname, optval, optlen);
}

/**
 * Connect a SAL socket to a remote address.
 *
 * Requires the network interface device to be up. After a successful
 * protocol-level connect, a registered TLS backend's connect operation is run
 * for TLS/DTLS sockets (presumably the handshake — depends on the backend).
 *
 * @return the protocol-level connect() result, or -1 on failure
 */
int sal_connect(int socket, const struct sockaddr *name, socklen_t namelen)
{
    struct sal_proto_family *pf;
    struct sal_socket *sock;
    int ret;

    SAL_SOCKET_OBJ_GET(sock, socket);
    SAL_NETDEV_IS_UP(sock->netdev);
    SAL_NETDEV_SOCKETOPS_VALID(sock->netdev, pf, connect);

    ret = pf->skt_ops->connect((int)(size_t)sock->user_data, name, namelen);

#ifdef SAL_USING_TLS
    if (ret >= 0 && SAL_SOCKOPS_PROTO_TLS_VALID(sock, connect) &&
        proto_tls->ops->connect(sock->user_data_tls) < 0)
    {
        return -1;
    }
#endif

    return ret;
}

int sal_listen(int socket, int backlog)
{
    struct sal_socket *sock;
860
    struct sal_proto_family *pf;
861

862 863
    /* get the socket object by socket descriptor */
    SAL_SOCKET_OBJ_GET(sock, socket);
864

865 866
    /* check the network interface socket opreation */
    SAL_NETDEV_SOCKETOPS_VALID(sock->netdev, pf, listen);
867

G
guo 已提交
868
    return pf->skt_ops->listen((int)(size_t)sock->user_data, backlog);
869 870 871
}

/**
 * Receive data from a SAL socket.
 *
 * Requires the network interface device to be up. For TLS/DTLS sockets with a
 * registered backend, data is read through the backend's recv operation, and
 * 'flags', 'from' and 'fromlen' are not used.
 *
 * @return number of bytes received, or -1 on failure
 */
int sal_recvfrom(int socket, void *mem, size_t len, int flags,
                 struct sockaddr *from, socklen_t *fromlen)
{
    struct sal_proto_family *pf;
    struct sal_socket *sock;

    SAL_SOCKET_OBJ_GET(sock, socket);
    SAL_NETDEV_IS_UP(sock->netdev);
    SAL_NETDEV_SOCKETOPS_VALID(sock->netdev, pf, recvfrom);

#ifdef SAL_USING_TLS
    if (SAL_SOCKOPS_PROTO_TLS_VALID(sock, recv))
    {
        int received = proto_tls->ops->recv(sock->user_data_tls, mem, len);

        return (received < 0) ? -1 : received;
    }
#endif

    return pf->skt_ops->recvfrom((int)(size_t)sock->user_data, mem, len, flags, from, fromlen);
}

/**
 * Send data on a SAL socket.
 *
 * Requires the network interface device to be up. For TLS/DTLS sockets with a
 * registered backend, data is written through the backend's send operation,
 * and 'flags', 'to' and 'tolen' are not used.
 *
 * @return number of bytes sent, or -1 on failure
 */
int sal_sendto(int socket, const void *dataptr, size_t size, int flags,
               const struct sockaddr *to, socklen_t tolen)
{
    struct sal_socket *sock;
    struct sal_proto_family *pf;

    /* get the socket object by socket descriptor */
    SAL_SOCKET_OBJ_GET(sock, socket);

    /* check the network interface is up status  */
    SAL_NETDEV_IS_UP(sock->netdev);
    /* check the network interface socket opreation */
    SAL_NETDEV_SOCKETOPS_VALID(sock->netdev, pf, sendto);

#ifdef SAL_USING_TLS
    if (SAL_SOCKOPS_PROTO_TLS_VALID(sock, send))
    {
        int ret;

        if ((ret = proto_tls->ops->send(sock->user_data_tls, dataptr, size)) < 0)
        {
            return -1;
        }
        return ret;
    }
#endif

    /* go through size_t so the pointer-to-int round trip is warning-free on 64-bit */
    return pf->skt_ops->sendto((int)(size_t)sock->user_data, dataptr, size, flags, to, tolen);
}

int sal_socket(int domain, int type, int protocol)
{
    int retval;
    int socket, proto_socket;
    struct sal_socket *sock;
944
    struct sal_proto_family *pf;
945 946 947 948 949 950 951

    /* allocate a new socket and registered socket options */
    socket = socket_new();
    if (socket < 0)
    {
        return -1;
    }
952 953

    /* get sal socket object by socket descriptor */
954
    sock = sal_get_socket(socket);
955 956
    if (sock == RT_NULL)
    {
957
        socket_delete(socket);
958 959
        return -1;
    }
960

961
    /* Initialize sal socket object */
962 963 964 965
    retval = socket_init(domain, type, protocol, &sock);
    if (retval < 0)
    {
        LOG_E("SAL socket protocol family input failed, return error %d.", retval);
966
        socket_delete(socket);
967 968 969
        return -1;
    }

970 971
    /* valid the network interface socket opreation */
    SAL_NETDEV_SOCKETOPS_VALID(sock->netdev, pf, socket);
972

973
    proto_socket = pf->skt_ops->socket(domain, type, protocol);
974 975
    if (proto_socket >= 0)
    {
976 977 978
#ifdef SAL_USING_TLS
        if (SAL_SOCKOPS_PROTO_TLS_VALID(sock, socket))
        {
979
            sock->user_data_tls = proto_tls->ops->socket(socket);
980 981
            if (sock->user_data_tls == RT_NULL)
            {
982
                socket_delete(socket);
983 984 985 986
                return -1;
            }
        }
#endif
G
guo 已提交
987
        sock->user_data = (void *)(size_t)proto_socket;
988 989
        return sock->socket;
    }
990
    socket_delete(socket);
991 992 993 994 995 996
    return -1;
}

int sal_closesocket(int socket)
{
    struct sal_socket *sock;
997 998
    struct sal_proto_family *pf;
    int error = 0;
999

1000 1001
    /* get the socket object by socket descriptor */
    SAL_SOCKET_OBJ_GET(sock, socket);
1002

L
luhuadong 已提交
1003
    /* clsoesocket operation not need to vaild network interface status */
1004
    /* valid the network interface socket opreation */
C
caixf 已提交
1005
    SAL_NETDEV_SOCKETOPS_VALID(sock->netdev, pf, closesocket);
1006

G
guo 已提交
1007
    if (pf->skt_ops->closesocket((int)(size_t)sock->user_data) == 0)
1008
    {
1009 1010 1011 1012 1013 1014 1015 1016 1017
#ifdef SAL_USING_TLS
        if (SAL_SOCKOPS_PROTO_TLS_VALID(sock, closesocket))
        {
            if (proto_tls->ops->closesocket(sock->user_data_tls) < 0)
            {
                return -1;
            }
        }
#endif
1018 1019 1020 1021 1022
        error = 0;
    }
    else
    {
        error = -1;
1023 1024
    }

1025 1026
    /* delete socket */
    socket_delete(socket);
1027 1028

    return error;
1029 1030 1031 1032 1033
}

int sal_ioctlsocket(int socket, long cmd, void *arg)
{
    struct sal_socket *sock;
1034
    struct sal_proto_family *pf;
G
guo 已提交
1035 1036 1037
    struct sockaddr_in *addr_in = RT_NULL;
    struct sockaddr *addr = RT_NULL;
    ip_addr_t input_ipaddr;
1038 1039
    /* get the socket object by socket descriptor */
    SAL_SOCKET_OBJ_GET(sock, socket);
1040

1041 1042
    /* check the network interface socket opreation */
    SAL_NETDEV_SOCKETOPS_VALID(sock->netdev, pf, ioctlsocket);
1043

G
guo 已提交
1044 1045 1046 1047 1048 1049 1050 1051 1052 1053 1054 1055 1056 1057 1058 1059 1060 1061 1062 1063 1064
    struct sal_ifreq *ifr = (struct sal_ifreq *)arg;

    if((sock->domain == AF_INET)&&(sock->netdev)&&(ifr != RT_NULL))
    {
        switch (cmd)
        {
        case SIOCGIFADDR:
            addr_in = (struct sockaddr_in *)&(ifr->ifr_ifru.ifru_addr);
#if NETDEV_IPV4 && NETDEV_IPV6
            addr_in->sin_addr.s_addr = sock->netdev->ip_addr.u_addr.ip4.addr;
#elif NETDEV_IPV4
            addr_in->sin_addr.s_addr = sock->netdev->ip_addr.addr;
#elif NETDEV_IPV6
#error "not only support IPV6"
#endif /* NETDEV_IPV4 && NETDEV_IPV6*/
            return 0;

        case SIOCSIFADDR:
            addr = (struct sockaddr *)&(ifr->ifr_ifru.ifru_addr);
            sal_sockaddr_to_ipaddr(addr,&input_ipaddr);
            netdev_set_ipaddr(sock->netdev,&input_ipaddr);
1065
            return 0;
G
guo 已提交
1066 1067 1068 1069 1070 1071 1072 1073 1074 1075 1076 1077 1078 1079 1080 1081

        case SIOCGIFNETMASK:
            addr_in = (struct sockaddr_in *)&(ifr->ifr_ifru.ifru_netmask);
#if NETDEV_IPV4 && NETDEV_IPV6
            addr_in->sin_addr.s_addr = sock->netdev->netmask.u_addr.ip4.addr;
#elif NETDEV_IPV4
            addr_in->sin_addr.s_addr = sock->netdev->netmask.addr;
#elif NETDEV_IPV6
#error "not only support IPV6"
#endif /* NETDEV_IPV4 && NETDEV_IPV6*/
            return 0;

        case SIOCSIFNETMASK:
            addr = (struct sockaddr *)&(ifr->ifr_ifru.ifru_netmask);
            sal_sockaddr_to_ipaddr(addr,&input_ipaddr);
            netdev_set_netmask(sock->netdev,&input_ipaddr);
1082
            return 0;
G
guo 已提交
1083 1084 1085 1086 1087 1088 1089 1090

        case SIOCGIFHWADDR:
            addr = (struct sockaddr *)&(ifr->ifr_ifru.ifru_hwaddr);
            rt_memcpy(addr->sa_data,sock->netdev->hwaddr,sock->netdev->hwaddr_len);
            return 0;

        case SIOCGIFMTU:
            ifr->ifr_ifru.ifru_mtu = sock->netdev->mtu;
1091
            return 0;
G
guo 已提交
1092 1093 1094 1095 1096 1097

        default:
            break;
        }
    }
    return pf->skt_ops->ioctlsocket((int)(size_t)sock->user_data, cmd, arg);
1098 1099
}

1100
#ifdef SAL_USING_POSIX
/**
 * Poll handler for a SAL socket reached through the file layer.
 *
 * The SAL descriptor is stored in the file's vnode data. After the socket
 * object, interface status and poll operation are checked by the SAL_*
 * macros, the request is handed to the protocol family's poll().
 *
 * @param file file object wrapping the socket
 * @param req  poll request
 *
 * @return the protocol family's poll() result, or whatever the SAL_* check
 *         macros return when a check fails
 */
int sal_poll(struct dfs_file *file, struct rt_pollreq *req)
{
    struct sal_proto_family *family;
    struct sal_socket *sk;
    int sockfd;

    sockfd = (int)(size_t)file->vnode->data;

    /* resolve the socket object, then require an up interface that can poll */
    SAL_SOCKET_OBJ_GET(sk, sockfd);
    SAL_NETDEV_IS_UP(sk->netdev);
    SAL_NETDEV_SOCKETOPS_VALID(sk->netdev, family, poll);

    /* delegate to the protocol stack */
    return family->skt_ops->poll(file, req);
}
#endif /* SAL_USING_POSIX */
1118 1119 1120

struct hostent *sal_gethostbyname(const char *name)
{
1121 1122
    struct netdev *netdev = netdev_default;
    struct sal_proto_family *pf;
1123

1124
    if (SAL_NETDEV_NETDBOPS_VALID(netdev, pf, gethostbyname))
1125
    {
1126 1127 1128 1129
        return pf->netdb_ops->gethostbyname(name);
    }
    else
    {
1130 1131
        /* get the first network interface device with up status */
        netdev = netdev_get_first_by_flags(NETDEV_FLAG_UP);
1132
        if (SAL_NETDEV_NETDBOPS_VALID(netdev, pf, gethostbyname))
1133
        {
1134
            return pf->netdb_ops->gethostbyname(name);
1135 1136 1137 1138 1139 1140 1141
        }
    }

    return RT_NULL;
}

int sal_gethostbyname_r(const char *name, struct hostent *ret, char *buf,
1142
                        size_t buflen, struct hostent **result, int *h_errnop)
1143
{
1144 1145
    struct netdev *netdev = netdev_default;
    struct sal_proto_family *pf;
1146

1147
    if (SAL_NETDEV_NETDBOPS_VALID(netdev, pf, gethostbyname_r))
1148
    {
1149 1150 1151 1152
        return pf->netdb_ops->gethostbyname_r(name, ret, buf, buflen, result, h_errnop);
    }
    else
    {
1153 1154
        /* get the first network interface device with up status */
        netdev = netdev_get_first_by_flags(NETDEV_FLAG_UP);
1155
        if (SAL_NETDEV_NETDBOPS_VALID(netdev, pf, gethostbyname_r))
1156
        {
1157
            return pf->netdb_ops->gethostbyname_r(name, ret, buf, buflen, result, h_errnop);
1158 1159 1160 1161 1162 1163 1164
        }
    }

    return -1;
}

/**
 * Resolve nodename/servname through the SAL netdb operations.
 *
 * The default network interface is tried first; if it does not pass
 * SAL_NETDEV_NETDBOPS_VALID for getaddrinfo, the first interface flagged
 * NETDEV_FLAG_UP is used. On success the result list is recorded in
 * sal_dev_res_tbl together with the interface that produced it, so
 * sal_freeaddrinfo() can release it through the same protocol family.
 *
 * @return the protocol family's getaddrinfo() result (RT_EOK on success), or
 *         -1 when no interface can serve the request or the result cannot be
 *         recorded (in which case *res is released and set to RT_NULL)
 */
int sal_getaddrinfo(const char *nodename,
                    const char *servname,
                    const struct addrinfo *hints,
                    struct addrinfo **res)
{
    struct netdev *netdev = netdev_default;
    struct sal_proto_family *pf;
    int ret = 0;
    rt_uint32_t i = 0;

    if (!(SAL_NETDEV_NETDBOPS_VALID(netdev, pf, getaddrinfo)))
    {
        /* get the first network interface device with up status */
        netdev = netdev_get_first_by_flags(NETDEV_FLAG_UP);
        if (!(SAL_NETDEV_NETDBOPS_VALID(netdev, pf, getaddrinfo)))
        {
            return -1;
        }
    }

    ret = pf->netdb_ops->getaddrinfo(nodename, servname, hints, res);
    if (ret != RT_EOK)
    {
        return ret;
    }

    /* Record the netdev and res. The table is shared by every caller, so the
     * free-slot search and claim must not interleave with another thread,
     * otherwise two results could claim the same slot. */
    rt_enter_critical();
    for (i = 0; i < SAL_SOCKETS_NUM; i++)
    {
        if (sal_dev_res_tbl[i].res == RT_NULL)
        {
            sal_dev_res_tbl[i].res = *res;
            sal_dev_res_tbl[i].netdev = netdev;
            break;
        }
    }
    rt_exit_critical();

    if (i >= SAL_SOCKETS_NUM)
    {
        /* sal_freeaddrinfo() can only free results found in the table, so an
         * unrecorded result would leak: release it now and report failure */
        LOG_E("SAL getaddrinfo result table is full.");
        if (pf->netdb_ops->freeaddrinfo != RT_NULL)
        {
            pf->netdb_ops->freeaddrinfo(*res);
        }
        *res = RT_NULL;
        return -1;
    }

    return ret;
}
1211 1212 1213

void sal_freeaddrinfo(struct addrinfo *ai)
{
1214 1215 1216
    struct netdev *netdev = RT_NULL;
    struct sal_proto_family *pf = RT_NULL;
    rt_uint32_t  i = 0;
1217

1218 1219
    /*when use the multi netdev, it must free the ai use the getaddrinfo netdev */
    for(i = 0; i < SAL_SOCKETS_NUM; i++)
1220
    {
1221
        if(sal_dev_res_tbl[i].res == ai)
1222
        {
1223 1224 1225 1226
            netdev = sal_dev_res_tbl[i].netdev;
            sal_dev_res_tbl[i].res = RT_NULL;
            sal_dev_res_tbl[i].netdev = RT_NULL;
            break;
1227 1228
        }
    }
1229 1230 1231 1232 1233 1234
    RT_ASSERT((i < SAL_SOCKETS_NUM));

    if (SAL_NETDBOPS_VALID(netdev, pf, freeaddrinfo))
    {
        pf->netdb_ops->freeaddrinfo(ai);
    }
1235
}