diff --git a/MAINTAINERS b/MAINTAINERS index d0cbb3d7a0ca8c383f703d86ec04160d7047a2a2..f0c37be4e04a6cb1077e08f5b0781e4dbef48616 100644 --- a/MAINTAINERS +++ b/MAINTAINERS @@ -14286,12 +14286,15 @@ S: Maintained F: include/linux/virtio_vsock.h F: include/uapi/linux/virtio_vsock.h F: include/uapi/linux/vsockmon.h +F: include/uapi/linux/vm_sockets_diag.h +F: net/vmw_vsock/diag.c F: net/vmw_vsock/af_vsock_tap.c F: net/vmw_vsock/virtio_transport_common.c F: net/vmw_vsock/virtio_transport.c F: drivers/net/vsockmon.c F: drivers/vhost/vsock.c F: drivers/vhost/vsock.h +F: tools/testing/vsock/ VIRTIO CONSOLE DRIVER M: Amit Shah diff --git a/include/net/af_vsock.h b/include/net/af_vsock.h index f9fb566e75cfd8ca1531dd733d443f9a6c929c62..9324ac2d9ff2db234cd5f3b8bf9f7e1a2bc60e52 100644 --- a/include/net/af_vsock.h +++ b/include/net/af_vsock.h @@ -22,11 +22,13 @@ #include "vsock_addr.h" -/* vsock-specific sock->sk_state constants */ -#define VSOCK_SS_LISTEN 255 - #define LAST_RESERVED_PORT 1023 +#define VSOCK_HASH_SIZE 251 +extern struct list_head vsock_bind_table[VSOCK_HASH_SIZE + 1]; +extern struct list_head vsock_connected_table[VSOCK_HASH_SIZE]; +extern spinlock_t vsock_table_lock; + #define vsock_sk(__sk) ((struct vsock_sock *)__sk) #define sk_vsock(__vsk) (&(__vsk)->sk) @@ -175,6 +177,18 @@ const struct vsock_transport *vsock_core_get_transport(void); /**** UTILS ****/ +/* vsock_table_lock must be held */ +static inline bool __vsock_in_bound_table(struct vsock_sock *vsk) +{ + return !list_empty(&vsk->bound_table); +} + +/* vsock_table_lock must be held */ +static inline bool __vsock_in_connected_table(struct vsock_sock *vsk) +{ + return !list_empty(&vsk->connected_table); +} + void vsock_release_pending(struct sock *pending); void vsock_add_pending(struct sock *listener, struct sock *pending); void vsock_remove_pending(struct sock *listener, struct sock *pending); diff --git a/include/uapi/linux/vm_sockets_diag.h b/include/uapi/linux/vm_sockets_diag.h new file mode 100644 index 0000000000000000000000000000000000000000..14cd7dc5a187c85fea9d98089bee1fcf785971ed --- /dev/null +++ b/include/uapi/linux/vm_sockets_diag.h @@ -0,0 +1,33 @@ +/* AF_VSOCK sock_diag(7) interface for querying open sockets */ + +#ifndef _UAPI__VM_SOCKETS_DIAG_H__ +#define _UAPI__VM_SOCKETS_DIAG_H__ + +#include + +/* Request */ +struct vsock_diag_req { + __u8 sdiag_family; /* must be AF_VSOCK */ + __u8 sdiag_protocol; /* must be 0 */ + __u16 pad; /* must be 0 */ + __u32 vdiag_states; /* query bitmap (e.g. 1 << TCP_LISTEN) */ + __u32 vdiag_ino; /* must be 0 (reserved) */ + __u32 vdiag_show; /* must be 0 (reserved) */ + __u32 vdiag_cookie[2]; +}; + +/* Response */ +struct vsock_diag_msg { + __u8 vdiag_family; /* AF_VSOCK */ + __u8 vdiag_type; /* SOCK_STREAM or SOCK_DGRAM */ + __u8 vdiag_state; /* sk_state (e.g. TCP_LISTEN) */ + __u8 vdiag_shutdown; /* local RCV_SHUTDOWN | SEND_SHUTDOWN */ + __u32 vdiag_src_cid; + __u32 vdiag_src_port; + __u32 vdiag_dst_cid; + __u32 vdiag_dst_port; + __u32 vdiag_ino; + __u32 vdiag_cookie[2]; +}; + +#endif /* _UAPI__VM_SOCKETS_DIAG_H__ */ diff --git a/net/vmw_vsock/Kconfig b/net/vmw_vsock/Kconfig index a24369d175fd6564fadcc56fd3d21e740c9eabc1..970f96489fe766ce5577607643289f4ee02e4f9a 100644 --- a/net/vmw_vsock/Kconfig +++ b/net/vmw_vsock/Kconfig @@ -15,6 +15,16 @@ config VSOCKETS To compile this driver as a module, choose M here: the module will be called vsock. If unsure, say N. +config VSOCKETS_DIAG + tristate "Virtual Sockets monitoring interface" + depends on VSOCKETS + default y + help + Support for PF_VSOCK sockets monitoring interface used by the ss tool. + If unsure, say Y. + + Enable this module so userspace applications can query open sockets. + config VMWARE_VMCI_VSOCKETS tristate "VMware VMCI transport for Virtual Sockets" depends on VSOCKETS && VMWARE_VMCI diff --git a/net/vmw_vsock/Makefile b/net/vmw_vsock/Makefile index e63d574234a98974be767ae9b3a081cf4949d7ff..64afc06805da37eb663236b5b320684aee85f6ee 100644 --- a/net/vmw_vsock/Makefile +++ b/net/vmw_vsock/Makefile @@ -1,4 +1,5 @@ obj-$(CONFIG_VSOCKETS) += vsock.o +obj-$(CONFIG_VSOCKETS_DIAG) += vsock_diag.o obj-$(CONFIG_VMWARE_VMCI_VSOCKETS) += vmw_vsock_vmci_transport.o obj-$(CONFIG_VIRTIO_VSOCKETS) += vmw_vsock_virtio_transport.o obj-$(CONFIG_VIRTIO_VSOCKETS_COMMON) += vmw_vsock_virtio_transport_common.o @@ -6,6 +7,8 @@ obj-$(CONFIG_HYPERV_VSOCKETS) += hv_sock.o vsock-y += af_vsock.o af_vsock_tap.o vsock_addr.o +vsock_diag-y += diag.o + vmw_vsock_vmci_transport-y += vmci_transport.o vmci_transport_notify.o \ vmci_transport_notify_qstate.o diff --git a/net/vmw_vsock/af_vsock.c b/net/vmw_vsock/af_vsock.c index dfc8c51e4d74ec378a338ab9bb2560b3811f393b..98359c19522f6ecbc3cd052f4389a6884bbff0cc 100644 --- a/net/vmw_vsock/af_vsock.c +++ b/net/vmw_vsock/af_vsock.c @@ -36,7 +36,7 @@ * not support simultaneous connects (two "client" sockets connecting). * * - "Server" sockets are referred to as listener sockets throughout this - * implementation because they are in the VSOCK_SS_LISTEN state. When a + * implementation because they are in the TCP_LISTEN state. When a * connection request is received (the second kind of socket mentioned above), * we create a new socket and refer to it as a pending socket. These pending * sockets are placed on the pending connection list of the listener socket. @@ -82,6 +82,15 @@ * argument, we must ensure the reference count is increased to ensure the * socket isn't freed before the function is run; the deferred function will * then drop the reference. + * + * - sk->sk_state uses the TCP state constants because they are widely used by + * other address families and exposed to userspace tools like ss(8): + * + * TCP_CLOSE - unconnected + * TCP_SYN_SENT - connecting + * TCP_ESTABLISHED - connected + * TCP_CLOSING - disconnecting + * TCP_LISTEN - listening */ #include @@ -153,7 +162,6 @@ EXPORT_SYMBOL_GPL(vm_sockets_get_local_cid); * vsock_bind_table[VSOCK_HASH_SIZE] is for unbound sockets. The hash function * mods with VSOCK_HASH_SIZE to ensure this. */ -#define VSOCK_HASH_SIZE 251 #define MAX_PORT_RETRIES 24 #define VSOCK_HASH(addr) ((addr)->svm_port % VSOCK_HASH_SIZE) @@ -168,9 +176,12 @@ EXPORT_SYMBOL_GPL(vm_sockets_get_local_cid); #define vsock_connected_sockets_vsk(vsk) \ vsock_connected_sockets(&(vsk)->remote_addr, &(vsk)->local_addr) -static struct list_head vsock_bind_table[VSOCK_HASH_SIZE + 1]; -static struct list_head vsock_connected_table[VSOCK_HASH_SIZE]; -static DEFINE_SPINLOCK(vsock_table_lock); +struct list_head vsock_bind_table[VSOCK_HASH_SIZE + 1]; +EXPORT_SYMBOL_GPL(vsock_bind_table); +struct list_head vsock_connected_table[VSOCK_HASH_SIZE]; +EXPORT_SYMBOL_GPL(vsock_connected_table); +DEFINE_SPINLOCK(vsock_table_lock); +EXPORT_SYMBOL_GPL(vsock_table_lock); /* Autobind this socket to the local address if necessary. */ static int vsock_auto_bind(struct vsock_sock *vsk) @@ -248,16 +259,6 @@ static struct sock *__vsock_find_connected_socket(struct sockaddr_vm *src, return NULL; } -static bool __vsock_in_bound_table(struct vsock_sock *vsk) -{ - return !list_empty(&vsk->bound_table); -} - -static bool __vsock_in_connected_table(struct vsock_sock *vsk) -{ - return !list_empty(&vsk->connected_table); -} - static void vsock_insert_unbound(struct vsock_sock *vsk) { spin_lock_bh(&vsock_table_lock); @@ -485,7 +486,7 @@ void vsock_pending_work(struct work_struct *work) if (vsock_in_connected_table(vsk)) vsock_remove_connected(vsk); - sk->sk_state = SS_FREE; + sk->sk_state = TCP_CLOSE; out: release_sock(sk); @@ -625,7 +626,6 @@ struct sock *__vsock_create(struct net *net, sk->sk_destruct = vsock_sk_destruct; sk->sk_backlog_rcv = vsock_queue_rcv_skb; - sk->sk_state = 0; sock_reset_flag(sk, SOCK_DONE); INIT_LIST_HEAD(&vsk->bound_table); @@ -899,7 +899,7 @@ static unsigned int vsock_poll(struct file *file, struct socket *sock, /* Listening sockets that have connections in their accept * queue can be read. */ - if (sk->sk_state == VSOCK_SS_LISTEN + if (sk->sk_state == TCP_LISTEN && !vsock_is_accept_queue_empty(sk)) mask |= POLLIN | POLLRDNORM; @@ -928,7 +928,7 @@ static unsigned int vsock_poll(struct file *file, struct socket *sock, } /* Connected sockets that can produce data can be written. */ - if (sk->sk_state == SS_CONNECTED) { + if (sk->sk_state == TCP_ESTABLISHED) { if (!(sk->sk_shutdown & SEND_SHUTDOWN)) { bool space_avail_now = false; int ret = transport->notify_poll_out( @@ -950,7 +950,7 @@ static unsigned int vsock_poll(struct file *file, struct socket *sock, * POLLOUT|POLLWRNORM when peer is closed and nothing to read, * but local send is not shutdown. */ - if (sk->sk_state == SS_UNCONNECTED) { + if (sk->sk_state == TCP_CLOSE) { if (!(sk->sk_shutdown & SEND_SHUTDOWN)) mask |= POLLOUT | POLLWRNORM; @@ -1120,9 +1120,9 @@ static void vsock_connect_timeout(struct work_struct *work) sk = sk_vsock(vsk); lock_sock(sk); - if (sk->sk_state == SS_CONNECTING && + if (sk->sk_state == TCP_SYN_SENT && (sk->sk_shutdown != SHUTDOWN_MASK)) { - sk->sk_state = SS_UNCONNECTED; + sk->sk_state = TCP_CLOSE; sk->sk_err = ETIMEDOUT; sk->sk_error_report(sk); cancel = 1; @@ -1168,7 +1168,7 @@ static int vsock_stream_connect(struct socket *sock, struct sockaddr *addr, err = -EALREADY; break; default: - if ((sk->sk_state == VSOCK_SS_LISTEN) || + if ((sk->sk_state == TCP_LISTEN) || vsock_addr_cast(addr, addr_len, &remote_addr) != 0) { err = -EINVAL; goto out; @@ -1191,7 +1191,7 @@ static int vsock_stream_connect(struct socket *sock, struct sockaddr *addr, if (err) goto out; - sk->sk_state = SS_CONNECTING; + sk->sk_state = TCP_SYN_SENT; err = transport->connect(vsk); if (err < 0) @@ -1211,7 +1211,7 @@ static int vsock_stream_connect(struct socket *sock, struct sockaddr *addr, timeout = vsk->connect_timeout; prepare_to_wait(sk_sleep(sk), &wait, TASK_INTERRUPTIBLE); - while (sk->sk_state != SS_CONNECTED && sk->sk_err == 0) { + while (sk->sk_state != TCP_ESTABLISHED && sk->sk_err == 0) { if (flags & O_NONBLOCK) { /* If we're not going to block, we schedule a timeout * function to generate a timeout on the connection @@ -1234,13 +1234,13 @@ static int vsock_stream_connect(struct socket *sock, struct sockaddr *addr, if (signal_pending(current)) { err = sock_intr_errno(timeout); - sk->sk_state = SS_UNCONNECTED; + sk->sk_state = TCP_CLOSE; sock->state = SS_UNCONNECTED; vsock_transport_cancel_pkt(vsk); goto out_wait; } else if (timeout == 0) { err = -ETIMEDOUT; - sk->sk_state = SS_UNCONNECTED; + sk->sk_state = TCP_CLOSE; sock->state = SS_UNCONNECTED; vsock_transport_cancel_pkt(vsk); goto out_wait; @@ -1251,7 +1251,7 @@ static int vsock_stream_connect(struct socket *sock, struct sockaddr *addr, if (sk->sk_err) { err = -sk->sk_err; - sk->sk_state = SS_UNCONNECTED; + sk->sk_state = TCP_CLOSE; sock->state = SS_UNCONNECTED; } else { err = 0; @@ -1284,7 +1284,7 @@ static int vsock_accept(struct socket *sock, struct socket *newsock, int flags, goto out; } - if (listener->sk_state != VSOCK_SS_LISTEN) { + if (listener->sk_state != TCP_LISTEN) { err = -EINVAL; goto out; } @@ -1374,7 +1374,7 @@ static int vsock_listen(struct socket *sock, int backlog) } sk->sk_max_ack_backlog = backlog; - sk->sk_state = VSOCK_SS_LISTEN; + sk->sk_state = TCP_LISTEN; err = 0; @@ -1554,7 +1554,7 @@ static int vsock_stream_sendmsg(struct socket *sock, struct msghdr *msg, /* Callers should not provide a destination with stream sockets. */ if (msg->msg_namelen) { - err = sk->sk_state == SS_CONNECTED ? -EISCONN : -EOPNOTSUPP; + err = sk->sk_state == TCP_ESTABLISHED ? -EISCONN : -EOPNOTSUPP; goto out; } @@ -1565,7 +1565,7 @@ static int vsock_stream_sendmsg(struct socket *sock, struct msghdr *msg, goto out; } - if (sk->sk_state != SS_CONNECTED || + if (sk->sk_state != TCP_ESTABLISHED || !vsock_addr_bound(&vsk->local_addr)) { err = -ENOTCONN; goto out; @@ -1689,7 +1689,7 @@ vsock_stream_recvmsg(struct socket *sock, struct msghdr *msg, size_t len, lock_sock(sk); - if (sk->sk_state != SS_CONNECTED) { + if (sk->sk_state != TCP_ESTABLISHED) { /* Recvmsg is supposed to return 0 if a peer performs an * orderly shutdown. Differentiate between that case and when a * peer has not connected or a local shutdown occured with the diff --git a/net/vmw_vsock/diag.c b/net/vmw_vsock/diag.c new file mode 100644 index 0000000000000000000000000000000000000000..31b5676522503bce482287588f99bfc837d7eccb --- /dev/null +++ b/net/vmw_vsock/diag.c @@ -0,0 +1,186 @@ +/* + * vsock sock_diag(7) module + * + * Copyright (C) 2017 Red Hat, Inc. + * Author: Stefan Hajnoczi + * + * This program is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License as published by the Free + * Software Foundation version 2 and no later version. + * + * This program is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for + * more details. + */ + +#include +#include +#include +#include + +static int sk_diag_fill(struct sock *sk, struct sk_buff *skb, + u32 portid, u32 seq, u32 flags) +{ + struct vsock_sock *vsk = vsock_sk(sk); + struct vsock_diag_msg *rep; + struct nlmsghdr *nlh; + + nlh = nlmsg_put(skb, portid, seq, SOCK_DIAG_BY_FAMILY, sizeof(*rep), + flags); + if (!nlh) + return -EMSGSIZE; + + rep = nlmsg_data(nlh); + rep->vdiag_family = AF_VSOCK; + + /* Lock order dictates that sk_lock is acquired before + * vsock_table_lock, so we cannot lock here. Simply don't take + * sk_lock; sk is guaranteed to stay alive since vsock_table_lock is + * held. + */ + rep->vdiag_type = sk->sk_type; + rep->vdiag_state = sk->sk_state; + rep->vdiag_shutdown = sk->sk_shutdown; + rep->vdiag_src_cid = vsk->local_addr.svm_cid; + rep->vdiag_src_port = vsk->local_addr.svm_port; + rep->vdiag_dst_cid = vsk->remote_addr.svm_cid; + rep->vdiag_dst_port = vsk->remote_addr.svm_port; + rep->vdiag_ino = sock_i_ino(sk); + + sock_diag_save_cookie(sk, rep->vdiag_cookie); + + return 0; +} + +static int vsock_diag_dump(struct sk_buff *skb, struct netlink_callback *cb) +{ + struct vsock_diag_req *req; + struct vsock_sock *vsk; + unsigned int bucket; + unsigned int last_i; + unsigned int table; + struct net *net; + unsigned int i; + + req = nlmsg_data(cb->nlh); + net = sock_net(skb->sk); + + /* State saved between calls: */ + table = cb->args[0]; + bucket = cb->args[1]; + i = last_i = cb->args[2]; + + /* TODO VMCI pending sockets? */ + + spin_lock_bh(&vsock_table_lock); + + /* Bind table (locally created sockets) */ + if (table == 0) { + while (bucket < ARRAY_SIZE(vsock_bind_table)) { + struct list_head *head = &vsock_bind_table[bucket]; + + i = 0; + list_for_each_entry(vsk, head, bound_table) { + struct sock *sk = sk_vsock(vsk); + + if (!net_eq(sock_net(sk), net)) + continue; + if (i < last_i) + goto next_bind; + if (!(req->vdiag_states & (1 << sk->sk_state))) + goto next_bind; + if (sk_diag_fill(sk, skb, + NETLINK_CB(cb->skb).portid, + cb->nlh->nlmsg_seq, + NLM_F_MULTI) < 0) + goto done; +next_bind: + i++; + } + last_i = 0; + bucket++; + } + + table++; + bucket = 0; + } + + /* Connected table (accepted connections) */ + while (bucket < ARRAY_SIZE(vsock_connected_table)) { + struct list_head *head = &vsock_connected_table[bucket]; + + i = 0; + list_for_each_entry(vsk, head, connected_table) { + struct sock *sk = sk_vsock(vsk); + + /* Skip sockets we've already seen above */ + if (__vsock_in_bound_table(vsk)) + continue; + + if (!net_eq(sock_net(sk), net)) + continue; + if (i < last_i) + goto next_connected; + if (!(req->vdiag_states & (1 << sk->sk_state))) + goto next_connected; + if (sk_diag_fill(sk, skb, + NETLINK_CB(cb->skb).portid, + cb->nlh->nlmsg_seq, + NLM_F_MULTI) < 0) + goto done; +next_connected: + i++; + } + last_i = 0; + bucket++; + } + +done: + spin_unlock_bh(&vsock_table_lock); + + cb->args[0] = table; + cb->args[1] = bucket; + cb->args[2] = i; + + return skb->len; +} + +static int vsock_diag_handler_dump(struct sk_buff *skb, struct nlmsghdr *h) +{ + int hdrlen = sizeof(struct vsock_diag_req); + struct net *net = sock_net(skb->sk); + + if (nlmsg_len(h) < hdrlen) + return -EINVAL; + + if (h->nlmsg_flags & NLM_F_DUMP) { + struct netlink_dump_control c = { + .dump = vsock_diag_dump, + }; + return netlink_dump_start(net->diag_nlsk, skb, h, &c); + } + + return -EOPNOTSUPP; +} + +static const struct sock_diag_handler vsock_diag_handler = { + .family = AF_VSOCK, + .dump = vsock_diag_handler_dump, +}; + +static int __init vsock_diag_init(void) +{ + return sock_diag_register(&vsock_diag_handler); +} + +static void __exit vsock_diag_exit(void) +{ + sock_diag_unregister(&vsock_diag_handler); +} + +module_init(vsock_diag_init); +module_exit(vsock_diag_exit); +MODULE_LICENSE("GPL"); +MODULE_ALIAS_NET_PF_PROTO_TYPE(PF_NETLINK, NETLINK_SOCK_DIAG, + 40 /* AF_VSOCK */); diff --git a/net/vmw_vsock/hyperv_transport.c b/net/vmw_vsock/hyperv_transport.c index 14ed5a344cdf302ba3f2d8e9dec4fb7c66fdd239..bbac023e70d1741ab2ad6a33d9ee7a954cf0e33c 100644 --- a/net/vmw_vsock/hyperv_transport.c +++ b/net/vmw_vsock/hyperv_transport.c @@ -310,7 +310,7 @@ static void hvs_close_connection(struct vmbus_channel *chan) struct sock *sk = get_per_channel_state(chan); struct vsock_sock *vsk = vsock_sk(sk); - sk->sk_state = SS_UNCONNECTED; + sk->sk_state = TCP_CLOSE; sock_set_flag(sk, SOCK_DONE); vsk->peer_shutdown |= SEND_SHUTDOWN | RCV_SHUTDOWN; @@ -344,8 +344,8 @@ static void hvs_open_connection(struct vmbus_channel *chan) if (!sk) return; - if ((conn_from_host && sk->sk_state != VSOCK_SS_LISTEN) || - (!conn_from_host && sk->sk_state != SS_CONNECTING)) + if ((conn_from_host && sk->sk_state != TCP_LISTEN) || + (!conn_from_host && sk->sk_state != TCP_SYN_SENT)) goto out; if (conn_from_host) { @@ -357,7 +357,7 @@ static void hvs_open_connection(struct vmbus_channel *chan) if (!new) goto out; - new->sk_state = SS_CONNECTING; + new->sk_state = TCP_SYN_SENT; vnew = vsock_sk(new); hvs_new = vnew->trans; hvs_new->chan = chan; @@ -384,7 +384,7 @@ static void hvs_open_connection(struct vmbus_channel *chan) vmbus_set_chn_rescind_callback(chan, hvs_close_connection); if (conn_from_host) { - new->sk_state = SS_CONNECTED; + new->sk_state = TCP_ESTABLISHED; sk->sk_ack_backlog++; hvs_addr_init(&vnew->local_addr, if_type); @@ -399,7 +399,7 @@ static void hvs_open_connection(struct vmbus_channel *chan) vsock_enqueue_accept(sk, new); release_sock(sk); } else { - sk->sk_state = SS_CONNECTED; + sk->sk_state = TCP_ESTABLISHED; sk->sk_socket->state = SS_CONNECTED; vsock_insert_connected(vsock_sk(sk)); diff --git a/net/vmw_vsock/virtio_transport.c b/net/vmw_vsock/virtio_transport.c index 403d86e80162e7796fd75249b1ae876d1eee1e6a..8e03bd3f3668b573c4d61a786e90a238abe9fe66 100644 --- a/net/vmw_vsock/virtio_transport.c +++ b/net/vmw_vsock/virtio_transport.c @@ -414,7 +414,7 @@ static void virtio_vsock_event_fill(struct virtio_vsock *vsock) static void virtio_vsock_reset_sock(struct sock *sk) { lock_sock(sk); - sk->sk_state = SS_UNCONNECTED; + sk->sk_state = TCP_CLOSE; sk->sk_err = ECONNRESET; sk->sk_error_report(sk); release_sock(sk); diff --git a/net/vmw_vsock/virtio_transport_common.c b/net/vmw_vsock/virtio_transport_common.c index edba7ab975639fc08c4257d0393391477979b1aa..3ae3a33da70bab034c29552c1f7e04b9b3d52c4d 100644 --- a/net/vmw_vsock/virtio_transport_common.c +++ b/net/vmw_vsock/virtio_transport_common.c @@ -708,7 +708,7 @@ static void virtio_transport_do_close(struct vsock_sock *vsk, sock_set_flag(sk, SOCK_DONE); vsk->peer_shutdown = SHUTDOWN_MASK; if (vsock_stream_has_data(vsk) <= 0) - sk->sk_state = SS_DISCONNECTING; + sk->sk_state = TCP_CLOSING; sk->sk_state_change(sk); if (vsk->close_work_scheduled && @@ -748,8 +748,8 @@ static bool virtio_transport_close(struct vsock_sock *vsk) { struct sock *sk = &vsk->sk; - if (!(sk->sk_state == SS_CONNECTED || - sk->sk_state == SS_DISCONNECTING)) + if (!(sk->sk_state == TCP_ESTABLISHED || + sk->sk_state == TCP_CLOSING)) return true; /* Already received SHUTDOWN from peer, reply with RST */ @@ -801,7 +801,7 @@ virtio_transport_recv_connecting(struct sock *sk, switch (le16_to_cpu(pkt->hdr.op)) { case VIRTIO_VSOCK_OP_RESPONSE: - sk->sk_state = SS_CONNECTED; + sk->sk_state = TCP_ESTABLISHED; sk->sk_socket->state = SS_CONNECTED; vsock_insert_connected(vsk); sk->sk_state_change(sk); @@ -821,7 +821,7 @@ virtio_transport_recv_connecting(struct sock *sk, destroy: virtio_transport_reset(vsk, pkt); - sk->sk_state = SS_UNCONNECTED; + sk->sk_state = TCP_CLOSE; sk->sk_err = skerr; sk->sk_error_report(sk); return err; @@ -857,7 +857,7 @@ virtio_transport_recv_connected(struct sock *sk, vsk->peer_shutdown |= SEND_SHUTDOWN; if (vsk->peer_shutdown == SHUTDOWN_MASK && vsock_stream_has_data(vsk) <= 0) - sk->sk_state = SS_DISCONNECTING; + sk->sk_state = TCP_CLOSING; if (le32_to_cpu(pkt->hdr.flags)) sk->sk_state_change(sk); break; @@ -928,7 +928,7 @@ virtio_transport_recv_listen(struct sock *sk, struct virtio_vsock_pkt *pkt) lock_sock_nested(child, SINGLE_DEPTH_NESTING); - child->sk_state = SS_CONNECTED; + child->sk_state = TCP_ESTABLISHED; vchild = vsock_sk(child); vsock_addr_init(&vchild->local_addr, le64_to_cpu(pkt->hdr.dst_cid), @@ -1016,18 +1016,18 @@ void virtio_transport_recv_pkt(struct virtio_vsock_pkt *pkt) sk->sk_write_space(sk); switch (sk->sk_state) { - case VSOCK_SS_LISTEN: + case TCP_LISTEN: virtio_transport_recv_listen(sk, pkt); virtio_transport_free_pkt(pkt); break; - case SS_CONNECTING: + case TCP_SYN_SENT: virtio_transport_recv_connecting(sk, pkt); virtio_transport_free_pkt(pkt); break; - case SS_CONNECTED: + case TCP_ESTABLISHED: virtio_transport_recv_connected(sk, pkt); break; - case SS_DISCONNECTING: + case TCP_CLOSING: virtio_transport_recv_disconnecting(sk, pkt); virtio_transport_free_pkt(pkt); break; diff --git a/net/vmw_vsock/vmci_transport.c b/net/vmw_vsock/vmci_transport.c index 0206155bff53530e3879b404bc86d74c3667d6cb..391775e3575c24a81ae70f75a451645c0b734b73 100644 --- a/net/vmw_vsock/vmci_transport.c +++ b/net/vmw_vsock/vmci_transport.c @@ -742,7 +742,7 @@ static int vmci_transport_recv_stream_cb(void *data, struct vmci_datagram *dg) /* The local context ID may be out of date, update it. */ vsk->local_addr.svm_cid = dst.svm_cid; - if (sk->sk_state == SS_CONNECTED) + if (sk->sk_state == TCP_ESTABLISHED) vmci_trans(vsk)->notify_ops->handle_notify_pkt( sk, pkt, true, &dst, &src, &bh_process_pkt); @@ -800,7 +800,9 @@ static void vmci_transport_handle_detach(struct sock *sk) * left in our consume queue. */ if (vsock_stream_has_data(vsk) <= 0) { - if (sk->sk_state == SS_CONNECTING) { + sk->sk_state = TCP_CLOSE; + + if (sk->sk_state == TCP_SYN_SENT) { /* The peer may detach from a queue pair while * we are still in the connecting state, i.e., * if the peer VM is killed after attaching to @@ -809,12 +811,10 @@ static void vmci_transport_handle_detach(struct sock *sk) * event like a reset. */ - sk->sk_state = SS_UNCONNECTED; sk->sk_err = ECONNRESET; sk->sk_error_report(sk); return; } - sk->sk_state = SS_UNCONNECTED; } sk->sk_state_change(sk); } @@ -882,17 +882,17 @@ static void vmci_transport_recv_pkt_work(struct work_struct *work) vsock_sk(sk)->local_addr.svm_cid = pkt->dg.dst.context; switch (sk->sk_state) { - case VSOCK_SS_LISTEN: + case TCP_LISTEN: vmci_transport_recv_listen(sk, pkt); break; - case SS_CONNECTING: + case TCP_SYN_SENT: /* Processing of pending connections for servers goes through * the listening socket, so see vmci_transport_recv_listen() * for that path. */ vmci_transport_recv_connecting_client(sk, pkt); break; - case SS_CONNECTED: + case TCP_ESTABLISHED: vmci_transport_recv_connected(sk, pkt); break; default: @@ -941,7 +941,7 @@ static int vmci_transport_recv_listen(struct sock *sk, vsock_sk(pending)->local_addr.svm_cid = pkt->dg.dst.context; switch (pending->sk_state) { - case SS_CONNECTING: + case TCP_SYN_SENT: err = vmci_transport_recv_connecting_server(sk, pending, pkt); @@ -1071,7 +1071,7 @@ static int vmci_transport_recv_listen(struct sock *sk, vsock_add_pending(sk, pending); sk->sk_ack_backlog++; - pending->sk_state = SS_CONNECTING; + pending->sk_state = TCP_SYN_SENT; vmci_trans(vpending)->produce_size = vmci_trans(vpending)->consume_size = qp_size; vmci_trans(vpending)->queue_pair_size = qp_size; @@ -1196,11 +1196,11 @@ vmci_transport_recv_connecting_server(struct sock *listener, * the socket will be valid until it is removed from the queue. * * If we fail sending the attach below, we remove the socket from the - * connected list and move the socket to SS_UNCONNECTED before + * connected list and move the socket to TCP_CLOSE before * releasing the lock, so a pending slow path processing of an incoming * packet will not see the socket in the connected state in that case. */ - pending->sk_state = SS_CONNECTED; + pending->sk_state = TCP_ESTABLISHED; vsock_insert_connected(vpending); @@ -1231,7 +1231,7 @@ vmci_transport_recv_connecting_server(struct sock *listener, destroy: pending->sk_err = skerr; - pending->sk_state = SS_UNCONNECTED; + pending->sk_state = TCP_CLOSE; /* As long as we drop our reference, all necessary cleanup will handle * when the cleanup function drops its reference and our destruct * implementation is called. Note that since the listen handler will @@ -1269,7 +1269,7 @@ vmci_transport_recv_connecting_client(struct sock *sk, * accounting (it can already be found since it's in the bound * table). */ - sk->sk_state = SS_CONNECTED; + sk->sk_state = TCP_ESTABLISHED; sk->sk_socket->state = SS_CONNECTED; vsock_insert_connected(vsk); sk->sk_state_change(sk); @@ -1337,7 +1337,7 @@ vmci_transport_recv_connecting_client(struct sock *sk, destroy: vmci_transport_send_reset(sk, pkt); - sk->sk_state = SS_UNCONNECTED; + sk->sk_state = TCP_CLOSE; sk->sk_err = skerr; sk->sk_error_report(sk); return err; @@ -1525,7 +1525,7 @@ static int vmci_transport_recv_connected(struct sock *sk, sock_set_flag(sk, SOCK_DONE); vsk->peer_shutdown = SHUTDOWN_MASK; if (vsock_stream_has_data(vsk) <= 0) - sk->sk_state = SS_DISCONNECTING; + sk->sk_state = TCP_CLOSING; sk->sk_state_change(sk); break; @@ -1789,7 +1789,7 @@ static int vmci_transport_connect(struct vsock_sock *vsk) err = vmci_transport_send_conn_request( sk, vmci_trans(vsk)->queue_pair_size); if (err < 0) { - sk->sk_state = SS_UNCONNECTED; + sk->sk_state = TCP_CLOSE; return err; } } else { @@ -1799,7 +1799,7 @@ static int vmci_transport_connect(struct vsock_sock *vsk) sk, vmci_trans(vsk)->queue_pair_size, supported_proto_versions); if (err < 0) { - sk->sk_state = SS_UNCONNECTED; + sk->sk_state = TCP_CLOSE; return err; } diff --git a/net/vmw_vsock/vmci_transport_notify.c b/net/vmw_vsock/vmci_transport_notify.c index 1406db4d97d14fe2d204f9b93ecfbdb53ec3e34f..41fb427f150a2880cf7534738ccacf0c1339da77 100644 --- a/net/vmw_vsock/vmci_transport_notify.c +++ b/net/vmw_vsock/vmci_transport_notify.c @@ -355,7 +355,7 @@ vmci_transport_notify_pkt_poll_in(struct sock *sk, * queue. Ask for notifications when there is something to * read. */ - if (sk->sk_state == SS_CONNECTED) { + if (sk->sk_state == TCP_ESTABLISHED) { if (!send_waiting_read(sk, 1)) return -1; diff --git a/net/vmw_vsock/vmci_transport_notify_qstate.c b/net/vmw_vsock/vmci_transport_notify_qstate.c index f3a0afc46208137a84227cc60d1d8fe9da8d7ec0..0cc84f2bb05e500956e4a61ba4188c0a05e7d8d8 100644 --- a/net/vmw_vsock/vmci_transport_notify_qstate.c +++ b/net/vmw_vsock/vmci_transport_notify_qstate.c @@ -176,7 +176,7 @@ vmci_transport_notify_pkt_poll_in(struct sock *sk, * queue. Ask for notifications when there is something to * read. */ - if (sk->sk_state == SS_CONNECTED) + if (sk->sk_state == TCP_ESTABLISHED) vsock_block_update_write_window(sk); *data_ready_now = false; } diff --git a/tools/testing/vsock/.gitignore b/tools/testing/vsock/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..dc5f11faf5300724395ceed6a7d30014c6a112e4 --- /dev/null +++ b/tools/testing/vsock/.gitignore @@ -0,0 +1,2 @@ +*.d +vsock_diag_test diff --git a/tools/testing/vsock/Makefile b/tools/testing/vsock/Makefile new file mode 100644 index 0000000000000000000000000000000000000000..66ba0924194df544101cedf3c1c72a6d8809a211 --- /dev/null +++ b/tools/testing/vsock/Makefile @@ -0,0 +1,9 @@ +all: test +test: vsock_diag_test +vsock_diag_test: vsock_diag_test.o timeout.o control.o + +CFLAGS += -g -O2 -Werror -Wall -I. -I../../include/uapi -I../../include -Wno-pointer-sign -fno-strict-overflow -fno-strict-aliasing -fno-common -MMD -U_FORTIFY_SOURCE -D_GNU_SOURCE +.PHONY: all test clean +clean: + ${RM} *.o *.d vsock_diag_test +-include *.d diff --git a/tools/testing/vsock/README b/tools/testing/vsock/README new file mode 100644 index 0000000000000000000000000000000000000000..2cc6d7302db64778d5d08d72680f2965c283d8d5 --- /dev/null +++ b/tools/testing/vsock/README @@ -0,0 +1,36 @@ +AF_VSOCK test suite +------------------- +These tests exercise net/vmw_vsock/ host<->guest sockets for VMware, KVM, and +Hyper-V. + +The following tests are available: + + * vsock_diag_test - vsock_diag.ko module for listing open sockets + +The following prerequisite steps are not automated and must be performed prior +to running tests: + +1. Build the kernel and these tests. +2. Install the kernel and tests on the host. +3. Install the kernel and tests inside the guest. +4. Boot the guest and ensure that the AF_VSOCK transport is enabled. + +Invoke test binaries in both directions as follows: + + # host=server, guest=client + (host)# $TEST_BINARY --mode=server \ + --control-port=1234 \ + --peer-cid=3 + (guest)# $TEST_BINARY --mode=client \ + --control-host=$HOST_IP \ + --control-port=1234 \ + --peer-cid=2 + + # host=client, guest=server + (guest)# $TEST_BINARY --mode=server \ + --control-port=1234 \ + --peer-cid=2 + (host)# $TEST_BINARY --mode=client \ + --control-port=$GUEST_IP \ + --control-port=1234 \ + --peer-cid=3 diff --git a/tools/testing/vsock/control.c b/tools/testing/vsock/control.c new file mode 100644 index 0000000000000000000000000000000000000000..90fd47f0e4227d58391b576538960b9f3f1b85b4 --- /dev/null +++ b/tools/testing/vsock/control.c @@ -0,0 +1,219 @@ +/* Control socket for client/server test execution + * + * Copyright (C) 2017 Red Hat, Inc. + * + * Author: Stefan Hajnoczi + * + * This program is free software; you can redistribute it and/or + * modify it under the terms of the GNU General Public License + * as published by the Free Software Foundation; version 2 + * of the License. + */ + +/* The client and server may need to coordinate to avoid race conditions like + * the client attempting to connect to a socket that the server is not + * listening on yet. The control socket offers a communications channel for + * such coordination tasks. + * + * If the client calls control_expectln("LISTENING"), then it will block until + * the server calls control_writeln("LISTENING"). This provides a simple + * mechanism for coordinating between the client and the server. + */ + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "timeout.h" +#include "control.h" + +static int control_fd = -1; + +/* Open the control socket, either in server or client mode */ +void control_init(const char *control_host, + const char *control_port, + bool server) +{ + struct addrinfo hints = { + .ai_socktype = SOCK_STREAM, + }; + struct addrinfo *result = NULL; + struct addrinfo *ai; + int ret; + + ret = getaddrinfo(control_host, control_port, &hints, &result); + if (ret != 0) { + fprintf(stderr, "%s\n", gai_strerror(ret)); + exit(EXIT_FAILURE); + } + + for (ai = result; ai; ai = ai->ai_next) { + int fd; + int val = 1; + + fd = socket(ai->ai_family, ai->ai_socktype, ai->ai_protocol); + if (fd < 0) + continue; + + if (!server) { + if (connect(fd, ai->ai_addr, ai->ai_addrlen) < 0) + goto next; + control_fd = fd; + printf("Control socket connected to %s:%s.\n", + control_host, control_port); + break; + } + + if (setsockopt(fd, SOL_SOCKET, SO_REUSEADDR, + &val, sizeof(val)) < 0) { + perror("setsockopt"); + exit(EXIT_FAILURE); + } + + if (bind(fd, ai->ai_addr, ai->ai_addrlen) < 0) + goto next; + if (listen(fd, 1) < 0) + goto next; + + printf("Control socket listening on %s:%s\n", + control_host, control_port); + fflush(stdout); + + control_fd = accept(fd, NULL, 0); + close(fd); + + if (control_fd < 0) { + perror("accept"); + exit(EXIT_FAILURE); + } + printf("Control socket connection accepted...\n"); + break; + +next: + close(fd); + } + + if (control_fd < 0) { + fprintf(stderr, "Control socket initialization failed. Invalid address %s:%s?\n", + control_host, control_port); + exit(EXIT_FAILURE); + } + + freeaddrinfo(result); +} + +/* Free resources */ +void control_cleanup(void) +{ + close(control_fd); + control_fd = -1; +} + +/* Write a line to the control socket */ +void control_writeln(const char *str) +{ + ssize_t len = strlen(str); + ssize_t ret; + + timeout_begin(TIMEOUT); + + do { + ret = send(control_fd, str, len, MSG_MORE); + timeout_check("send"); + } while (ret < 0 && errno == EINTR); + + if (ret != len) { + perror("send"); + exit(EXIT_FAILURE); + } + + do { + ret = send(control_fd, "\n", 1, 0); + timeout_check("send"); + } while (ret < 0 && errno == EINTR); + + if (ret != 1) { + perror("send"); + exit(EXIT_FAILURE); + } + + timeout_end(); +} + +/* Return the next line from the control socket (without the trailing newline). + * + * The program terminates if a timeout occurs. + * + * The caller must free() the returned string. + */ +char *control_readln(void) +{ + char *buf = NULL; + size_t idx = 0; + size_t buflen = 0; + + timeout_begin(TIMEOUT); + + for (;;) { + ssize_t ret; + + if (idx >= buflen) { + char *new_buf; + + new_buf = realloc(buf, buflen + 80); + if (!new_buf) { + perror("realloc"); + exit(EXIT_FAILURE); + } + + buf = new_buf; + buflen += 80; + } + + do { + ret = recv(control_fd, &buf[idx], 1, 0); + timeout_check("recv"); + } while (ret < 0 && errno == EINTR); + + if (ret == 0) { + fprintf(stderr, "unexpected EOF on control socket\n"); + exit(EXIT_FAILURE); + } + + if (ret != 1) { + perror("recv"); + exit(EXIT_FAILURE); + } + + if (buf[idx] == '\n') { + buf[idx] = '\0'; + break; + } + + idx++; + } + + timeout_end(); + + return buf; +} + +/* Wait until a given line is received or a timeout occurs */ +void control_expectln(const char *str) +{ + char *line; + + line = control_readln(); + if (strcmp(str, line) != 0) { + fprintf(stderr, "expected \"%s\" on control socket, got \"%s\"\n", + str, line); + exit(EXIT_FAILURE); + } + + free(line); +} diff --git a/tools/testing/vsock/control.h b/tools/testing/vsock/control.h new file mode 100644 index 0000000000000000000000000000000000000000..54a07efd267c2f37cdc83521e524968cf5403894 --- /dev/null +++ b/tools/testing/vsock/control.h @@ -0,0 +1,13 @@ +#ifndef CONTROL_H +#define CONTROL_H + +#include + +void control_init(const char *control_host, const char *control_port, + bool server); +void control_cleanup(void); +void control_writeln(const char *str); +char *control_readln(void); +void control_expectln(const char *str); + +#endif /* CONTROL_H */ diff --git a/tools/testing/vsock/timeout.c b/tools/testing/vsock/timeout.c new file mode 100644 index 0000000000000000000000000000000000000000..c49b3003b2dba671251b94f2c5f1ac41d42c2eb5 --- /dev/null +++ b/tools/testing/vsock/timeout.c @@ -0,0 +1,64 @@ +/* Timeout API for single-threaded programs that use blocking + * syscalls (read/write/send/recv/connect/accept). + * + * Copyright (C) 2017 Red Hat, Inc. + * + * Author: Stefan Hajnoczi + * + * This program is free software; you can redistribute it and/or + * modify it under the terms of the GNU General Public License + * as published by the Free Software Foundation; version 2 + * of the License. + */ + +/* Use the following pattern: + * + * timeout_begin(TIMEOUT); + * do { + * ret = accept(...); + * timeout_check("accept"); + * } while (ret < 0 && ret == EINTR); + * timeout_end(); + */ + +#include +#include +#include +#include +#include "timeout.h" + +static volatile bool timeout; + +/* SIGALRM handler function. Do not use sleep(2), alarm(2), or + * setitimer(2) while using this API - they may interfere with each + * other. + */ +void sigalrm(int signo) +{ + timeout = true; +} + +/* Start a timeout. Call timeout_check() to verify that the timeout hasn't + * expired. timeout_end() must be called to stop the timeout. Timeouts cannot + * be nested. + */ +void timeout_begin(unsigned int seconds) +{ + alarm(seconds); +} + +/* Exit with an error message if the timeout has expired */ +void timeout_check(const char *operation) +{ + if (timeout) { + fprintf(stderr, "%s timed out\n", operation); + exit(EXIT_FAILURE); + } +} + +/* Stop a timeout */ +void timeout_end(void) +{ + alarm(0); + timeout = false; +} diff --git a/tools/testing/vsock/timeout.h b/tools/testing/vsock/timeout.h new file mode 100644 index 0000000000000000000000000000000000000000..77db9ce9860a97432de20caee1f67fb499c35ddc --- /dev/null +++ b/tools/testing/vsock/timeout.h @@ -0,0 +1,14 @@ +#ifndef TIMEOUT_H +#define TIMEOUT_H + +enum { + /* Default timeout */ + TIMEOUT = 10 /* seconds */ +}; + +void sigalrm(int signo); +void timeout_begin(unsigned int seconds); +void timeout_check(const char *operation); +void timeout_end(void); + +#endif /* TIMEOUT_H */ diff --git a/tools/testing/vsock/vsock_diag_test.c b/tools/testing/vsock/vsock_diag_test.c new file mode 100644 index 0000000000000000000000000000000000000000..e896a4af52f4025b5905175d84c0d0c1d5aa0c6b --- /dev/null +++ b/tools/testing/vsock/vsock_diag_test.c @@ -0,0 +1,681 @@ +/* + * vsock_diag_test - vsock_diag.ko test suite + * + * Copyright (C) 2017 Red Hat, Inc. + * + * Author: Stefan Hajnoczi + * + * This program is free software; you can redistribute it and/or + * modify it under the terms of the GNU General Public License + * as published by the Free Software Foundation; version 2 + * of the License. + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "../../../include/uapi/linux/vm_sockets.h" +#include "../../../include/uapi/linux/vm_sockets_diag.h" + +#include "timeout.h" +#include "control.h" + +enum test_mode { + TEST_MODE_UNSET, + TEST_MODE_CLIENT, + TEST_MODE_SERVER +}; + +/* Per-socket status */ +struct vsock_stat { + struct list_head list; + struct vsock_diag_msg msg; +}; + +static const char *sock_type_str(int type) +{ + switch (type) { + case SOCK_DGRAM: + return "DGRAM"; + case SOCK_STREAM: + return "STREAM"; + default: + return "INVALID TYPE"; + } +} + +static const char *sock_state_str(int state) +{ + switch (state) { + case TCP_CLOSE: + return "UNCONNECTED"; + case TCP_SYN_SENT: + return "CONNECTING"; + case TCP_ESTABLISHED: + return "CONNECTED"; + case TCP_CLOSING: + return "DISCONNECTING"; + case TCP_LISTEN: + return "LISTEN"; + default: + return "INVALID STATE"; + } +} + +static const char *sock_shutdown_str(int shutdown) +{ + switch (shutdown) { + case 1: + return "RCV_SHUTDOWN"; + case 2: + return "SEND_SHUTDOWN"; + case 3: + return "RCV_SHUTDOWN | SEND_SHUTDOWN"; + default: + return "0"; + } +} + +static void print_vsock_addr(FILE *fp, unsigned int cid, unsigned int port) +{ + if (cid == VMADDR_CID_ANY) + fprintf(fp, "*:"); + else + fprintf(fp, "%u:", cid); + + if (port == VMADDR_PORT_ANY) + fprintf(fp, "*"); + else + fprintf(fp, "%u", port); +} + +static void print_vsock_stat(FILE *fp, struct vsock_stat *st) +{ + print_vsock_addr(fp, st->msg.vdiag_src_cid, st->msg.vdiag_src_port); + fprintf(fp, " "); + print_vsock_addr(fp, st->msg.vdiag_dst_cid, st->msg.vdiag_dst_port); + fprintf(fp, " %s %s %s %u\n", + sock_type_str(st->msg.vdiag_type), + sock_state_str(st->msg.vdiag_state), + sock_shutdown_str(st->msg.vdiag_shutdown), + st->msg.vdiag_ino); +} + +static void print_vsock_stats(FILE *fp, struct list_head *head) +{ + struct vsock_stat *st; + + list_for_each_entry(st, head, list) + print_vsock_stat(fp, st); +} + +static struct vsock_stat *find_vsock_stat(struct list_head *head, int fd) +{ + struct vsock_stat *st; + struct stat stat; + + if (fstat(fd, &stat) < 0) { + perror("fstat"); + exit(EXIT_FAILURE); + } + + list_for_each_entry(st, head, list) + if (st->msg.vdiag_ino == stat.st_ino) + return st; + + fprintf(stderr, "cannot find fd %d\n", fd); + exit(EXIT_FAILURE); +} + +static void check_no_sockets(struct list_head *head) +{ + if (!list_empty(head)) { + fprintf(stderr, "expected no sockets\n"); + print_vsock_stats(stderr, head); + exit(1); + } +} + +static void check_num_sockets(struct list_head *head, int expected) +{ + struct list_head *node; + int n = 0; + + list_for_each(node, head) + n++; + + if (n != expected) { + fprintf(stderr, "expected %d sockets, found %d\n", + expected, n); + print_vsock_stats(stderr, head); + exit(EXIT_FAILURE); + } +} + +static void check_socket_state(struct vsock_stat *st, __u8 state) +{ + if (st->msg.vdiag_state != state) { + fprintf(stderr, "expected socket state %#x, got %#x\n", + state, st->msg.vdiag_state); + exit(EXIT_FAILURE); + } +} + +static void send_req(int fd) +{ + struct sockaddr_nl nladdr = { + .nl_family = AF_NETLINK, + }; + struct { + struct nlmsghdr nlh; + struct vsock_diag_req vreq; + } req = { + .nlh = { + .nlmsg_len = sizeof(req), + .nlmsg_type = SOCK_DIAG_BY_FAMILY, + .nlmsg_flags = NLM_F_REQUEST | NLM_F_DUMP, + }, + .vreq = { + .sdiag_family = AF_VSOCK, + .vdiag_states = ~(__u32)0, + }, + }; + struct iovec iov = { + .iov_base = &req, + .iov_len = sizeof(req), + }; + struct msghdr msg = { + .msg_name = &nladdr, + .msg_namelen = sizeof(nladdr), + .msg_iov = &iov, + .msg_iovlen = 1, + }; + + for (;;) { + if (sendmsg(fd, &msg, 0) < 0) { + if (errno == EINTR) + continue; + + perror("sendmsg"); + exit(EXIT_FAILURE); + } + + return; + } +} + +static ssize_t recv_resp(int fd, void *buf, size_t len) +{ + struct sockaddr_nl nladdr = { + .nl_family = AF_NETLINK, + }; + struct iovec iov = { + .iov_base = buf, + .iov_len = len, + }; + struct msghdr msg = { + .msg_name = &nladdr, + .msg_namelen = sizeof(nladdr), + .msg_iov = &iov, + .msg_iovlen = 1, + }; + ssize_t ret; + + do { + ret = recvmsg(fd, &msg, 0); + } while (ret < 0 && errno == EINTR); + + if (ret < 0) { + perror("recvmsg"); + exit(EXIT_FAILURE); + } + + return ret; +} + +static void add_vsock_stat(struct list_head *sockets, + const struct vsock_diag_msg *resp) +{ + struct vsock_stat *st; + + st = malloc(sizeof(*st)); + if (!st) { + perror("malloc"); + exit(EXIT_FAILURE); + } + + st->msg = *resp; + list_add_tail(&st->list, sockets); +} + +/* + * Read vsock stats into a list. + */ +static void read_vsock_stat(struct list_head *sockets) +{ + long buf[8192 / sizeof(long)]; + int fd; + + fd = socket(AF_NETLINK, SOCK_RAW, NETLINK_SOCK_DIAG); + if (fd < 0) { + perror("socket"); + exit(EXIT_FAILURE); + } + + send_req(fd); + + for (;;) { + const struct nlmsghdr *h; + ssize_t ret; + + ret = recv_resp(fd, buf, sizeof(buf)); + if (ret == 0) + goto done; + if (ret < sizeof(*h)) { + fprintf(stderr, "short read of %zd bytes\n", ret); + exit(EXIT_FAILURE); + } + + h = (struct nlmsghdr *)buf; + + while (NLMSG_OK(h, ret)) { + if (h->nlmsg_type == NLMSG_DONE) + goto done; + + if (h->nlmsg_type == NLMSG_ERROR) { + const struct nlmsgerr *err = NLMSG_DATA(h); + + if (h->nlmsg_len < NLMSG_LENGTH(sizeof(*err))) + fprintf(stderr, "NLMSG_ERROR\n"); + else { + errno = -err->error; + perror("NLMSG_ERROR"); + } + + exit(EXIT_FAILURE); + } + + if (h->nlmsg_type != SOCK_DIAG_BY_FAMILY) { + fprintf(stderr, "unexpected nlmsg_type %#x\n", + h->nlmsg_type); + exit(EXIT_FAILURE); + } + if (h->nlmsg_len < + NLMSG_LENGTH(sizeof(struct vsock_diag_msg))) { + fprintf(stderr, "short vsock_diag_msg\n"); + exit(EXIT_FAILURE); + } + + add_vsock_stat(sockets, NLMSG_DATA(h)); + + h = NLMSG_NEXT(h, ret); + } + } + +done: + close(fd); +} + +static void free_sock_stat(struct list_head *sockets) +{ + struct vsock_stat *st; + struct vsock_stat *next; + + list_for_each_entry_safe(st, next, sockets, list) + free(st); +} + +static void test_no_sockets(unsigned int peer_cid) +{ + LIST_HEAD(sockets); + + read_vsock_stat(&sockets); + + check_no_sockets(&sockets); + + free_sock_stat(&sockets); +} + +static void test_listen_socket_server(unsigned int peer_cid) +{ + union { + struct sockaddr sa; + struct sockaddr_vm svm; + } addr = { + .svm = { + .svm_family = AF_VSOCK, + .svm_port = 1234, + .svm_cid = VMADDR_CID_ANY, + }, + }; + LIST_HEAD(sockets); + struct vsock_stat *st; + int fd; + + fd = socket(AF_VSOCK, SOCK_STREAM, 0); + + if (bind(fd, &addr.sa, sizeof(addr.svm)) < 0) { + perror("bind"); + exit(EXIT_FAILURE); + } + + if (listen(fd, 1) < 0) { + perror("listen"); + exit(EXIT_FAILURE); + } + + read_vsock_stat(&sockets); + + check_num_sockets(&sockets, 1); + st = find_vsock_stat(&sockets, fd); + check_socket_state(st, TCP_LISTEN); + + close(fd); + free_sock_stat(&sockets); +} + +static void test_connect_client(unsigned int peer_cid) +{ + union { + struct sockaddr sa; + struct sockaddr_vm svm; + } addr = { + .svm = { + .svm_family = AF_VSOCK, + .svm_port = 1234, + .svm_cid = peer_cid, + }, + }; + int fd; + int ret; + LIST_HEAD(sockets); + struct vsock_stat *st; + + control_expectln("LISTENING"); + + fd = socket(AF_VSOCK, SOCK_STREAM, 0); + + timeout_begin(TIMEOUT); + do { + ret = connect(fd, &addr.sa, sizeof(addr.svm)); + timeout_check("connect"); + } while (ret < 0 && errno == EINTR); + timeout_end(); + + if (ret < 0) { + perror("connect"); + exit(EXIT_FAILURE); + } + + read_vsock_stat(&sockets); + + check_num_sockets(&sockets, 1); + st = find_vsock_stat(&sockets, fd); + check_socket_state(st, TCP_ESTABLISHED); + + control_expectln("DONE"); + control_writeln("DONE"); + + close(fd); + free_sock_stat(&sockets); +} + +static void test_connect_server(unsigned int peer_cid) +{ + union { + struct sockaddr sa; + struct sockaddr_vm svm; + } addr = { + .svm = { + .svm_family = AF_VSOCK, + .svm_port = 1234, + .svm_cid = VMADDR_CID_ANY, + }, + }; + union { + struct sockaddr sa; + struct sockaddr_vm svm; + } clientaddr; + socklen_t clientaddr_len = sizeof(clientaddr.svm); + LIST_HEAD(sockets); + struct vsock_stat *st; + int fd; + int client_fd; + + fd = socket(AF_VSOCK, SOCK_STREAM, 0); + + if (bind(fd, &addr.sa, sizeof(addr.svm)) < 0) { + perror("bind"); + exit(EXIT_FAILURE); + } + + if (listen(fd, 1) < 0) { + perror("listen"); + exit(EXIT_FAILURE); + } + + control_writeln("LISTENING"); + + timeout_begin(TIMEOUT); + do { + client_fd = accept(fd, &clientaddr.sa, &clientaddr_len); + timeout_check("accept"); + } while (client_fd < 0 && errno == EINTR); + timeout_end(); + + if (client_fd < 0) { + perror("accept"); + exit(EXIT_FAILURE); + } + if (clientaddr.sa.sa_family != AF_VSOCK) { + fprintf(stderr, "expected AF_VSOCK from accept(2), got %d\n", + clientaddr.sa.sa_family); + exit(EXIT_FAILURE); + } + if (clientaddr.svm.svm_cid != peer_cid) { + fprintf(stderr, "expected peer CID %u from accept(2), got %u\n", + peer_cid, clientaddr.svm.svm_cid); + exit(EXIT_FAILURE); + } + + read_vsock_stat(&sockets); + + check_num_sockets(&sockets, 2); + find_vsock_stat(&sockets, fd); + st = find_vsock_stat(&sockets, client_fd); + check_socket_state(st, TCP_ESTABLISHED); + + control_writeln("DONE"); + control_expectln("DONE"); + + close(client_fd); + close(fd); + free_sock_stat(&sockets); +} + +static struct { + const char *name; + void (*run_client)(unsigned int peer_cid); + void (*run_server)(unsigned int peer_cid); +} test_cases[] = { + { + .name = "No sockets", + .run_server = test_no_sockets, + }, + { + .name = "Listen socket", + .run_server = test_listen_socket_server, + }, + { + .name = "Connect", + .run_client = test_connect_client, + .run_server = test_connect_server, + }, + {}, +}; + +static void init_signals(void) +{ + struct sigaction act = { + .sa_handler = sigalrm, + }; + + sigaction(SIGALRM, &act, NULL); + signal(SIGPIPE, SIG_IGN); +} + +static unsigned int parse_cid(const char *str) +{ + char *endptr = NULL; + unsigned long int n; + + errno = 0; + n = strtoul(str, &endptr, 10); + if (errno || *endptr != '\0') { + fprintf(stderr, "malformed CID \"%s\"\n", str); + exit(EXIT_FAILURE); + } + return n; +} + +static const char optstring[] = ""; +static const struct option longopts[] = { + { + .name = "control-host", + .has_arg = required_argument, + .val = 'H', + }, + { + .name = "control-port", + .has_arg = required_argument, + .val = 'P', + }, + { + .name = "mode", + .has_arg = required_argument, + .val = 'm', + }, + { + .name = "peer-cid", + .has_arg = required_argument, + .val = 'p', + }, + { + .name = "help", + .has_arg = no_argument, + .val = '?', + }, + {}, +}; + +static void usage(void) +{ + fprintf(stderr, "Usage: vsock_diag_test [--help] [--control-host=] --control-port= --mode=client|server --peer-cid=\n" + "\n" + " Server: vsock_diag_test --control-port=1234 --mode=server --peer-cid=3\n" + " Client: vsock_diag_test --control-host=192.168.0.1 --control-port=1234 --mode=client --peer-cid=2\n" + "\n" + "Run vsock_diag.ko tests. Must be launched in both\n" + "guest and host. One side must use --mode=client and\n" + "the other side must use --mode=server.\n" + "\n" + "A TCP control socket connection is used to coordinate tests\n" + "between the client and the server. The server requires a\n" + "listen address and the client requires an address to\n" + "connect to.\n" + "\n" + "The CID of the other side must be given with --peer-cid=.\n"); + exit(EXIT_FAILURE); +} + +int main(int argc, char **argv) +{ + const char *control_host = NULL; + const char *control_port = NULL; + int mode = TEST_MODE_UNSET; + unsigned int peer_cid = VMADDR_CID_ANY; + int i; + + init_signals(); + + for (;;) { + int opt = getopt_long(argc, argv, optstring, longopts, NULL); + + if (opt == -1) + break; + + switch (opt) { + case 'H': + control_host = optarg; + break; + case 'm': + if (strcmp(optarg, "client") == 0) + mode = TEST_MODE_CLIENT; + else if (strcmp(optarg, "server") == 0) + mode = TEST_MODE_SERVER; + else { + fprintf(stderr, "--mode must be \"client\" or \"server\"\n"); + return EXIT_FAILURE; + } + break; + case 'p': + peer_cid = parse_cid(optarg); + break; + case 'P': + control_port = optarg; + break; + case '?': + default: + usage(); + } + } + + if (!control_port) + usage(); + if (mode == TEST_MODE_UNSET) + usage(); + if (peer_cid == VMADDR_CID_ANY) + usage(); + + if (!control_host) { + if (mode != TEST_MODE_SERVER) + usage(); + control_host = "0.0.0.0"; + } + + control_init(control_host, control_port, mode == TEST_MODE_SERVER); + + for (i = 0; test_cases[i].name; i++) { + void (*run)(unsigned int peer_cid); + + printf("%s...", test_cases[i].name); + fflush(stdout); + + if (mode == TEST_MODE_CLIENT) + run = test_cases[i].run_client; + else + run = test_cases[i].run_server; + + if (run) + run(peer_cid); + + printf("ok\n"); + } + + control_cleanup(); + return EXIT_SUCCESS; +}