diff --git a/drivers/vhost/vsock.c b/drivers/vhost/vsock.c index b235f4bbe8ea1d1d476980f055fa8921b4a03ad2..fdda9ec625adc5f98e7dfa2536c6f8a60bec41c8 100644 --- a/drivers/vhost/vsock.c +++ b/drivers/vhost/vsock.c @@ -386,6 +386,8 @@ static bool vhost_vsock_more_replies(struct vhost_vsock *vsock) static struct virtio_transport vhost_transport = { .transport = { + .module = THIS_MODULE, + .get_local_cid = vhost_transport_get_local_cid, .init = virtio_transport_do_socket_init, diff --git a/include/net/af_vsock.h b/include/net/af_vsock.h index cf5c3691251bd5c45f1fb3bd71d78579d65807f5..4206dc6d813f21eb04cad1a673fc1d59174b6054 100644 --- a/include/net/af_vsock.h +++ b/include/net/af_vsock.h @@ -100,6 +100,8 @@ struct vsock_transport_send_notify_data { #define VSOCK_TRANSPORT_F_DGRAM 0x00000004 struct vsock_transport { + struct module *module; + /* Initialize/tear-down socket. */ int (*init)(struct vsock_sock *, struct vsock_sock *); void (*destruct)(struct vsock_sock *); diff --git a/net/vmw_vsock/af_vsock.c b/net/vmw_vsock/af_vsock.c index 5357714b610419463f34e64e57cf45dbbc98f38a..5cb0ae42d91608e1f9a038cea09cf1c61728d362 100644 --- a/net/vmw_vsock/af_vsock.c +++ b/net/vmw_vsock/af_vsock.c @@ -380,6 +380,16 @@ void vsock_enqueue_accept(struct sock *listener, struct sock *connected) } EXPORT_SYMBOL_GPL(vsock_enqueue_accept); +static void vsock_deassign_transport(struct vsock_sock *vsk) +{ + if (!vsk->transport) + return; + + vsk->transport->destruct(vsk); + module_put(vsk->transport->module); + vsk->transport = NULL; +} + /* Assign a transport to a socket and call the .init transport callback. * * Note: for stream socket this must be called when vsk->remote_addr is set @@ -418,10 +428,13 @@ int vsock_assign_transport(struct vsock_sock *vsk, struct vsock_sock *psk) return 0; vsk->transport->release(vsk); - vsk->transport->destruct(vsk); + vsock_deassign_transport(vsk); } - if (!new_transport) + /* We increase the module refcnt to prevent the transport unloading + * while there are open sockets assigned to it. + */ + if (!new_transport || !try_module_get(new_transport->module)) return -ENODEV; vsk->transport = new_transport; @@ -741,8 +754,7 @@ static void vsock_sk_destruct(struct sock *sk) { struct vsock_sock *vsk = vsock_sk(sk); - if (vsk->transport) - vsk->transport->destruct(vsk); + vsock_deassign_transport(vsk); /* When clearing these addresses, there's no need to set the family and * possibly register the address family with the kernel. diff --git a/net/vmw_vsock/hyperv_transport.c b/net/vmw_vsock/hyperv_transport.c index 1c9e65d7d94da9b21e63940cf3b7e1e8ca5def77..3c7d07a99fc5436cf4a78a650292c7a1a0154ff1 100644 --- a/net/vmw_vsock/hyperv_transport.c +++ b/net/vmw_vsock/hyperv_transport.c @@ -857,6 +857,8 @@ int hvs_notify_send_post_enqueue(struct vsock_sock *vsk, ssize_t written, } static struct vsock_transport hvs_transport = { + .module = THIS_MODULE, + .get_local_cid = hvs_get_local_cid, .init = hvs_sock_init, diff --git a/net/vmw_vsock/virtio_transport.c b/net/vmw_vsock/virtio_transport.c index 83ad85050384ca7f693418ea4416f24473444e18..1458c5c8b64d2784d79f8def7d709e734dfa13bf 100644 --- a/net/vmw_vsock/virtio_transport.c +++ b/net/vmw_vsock/virtio_transport.c @@ -462,6 +462,8 @@ static void virtio_vsock_rx_done(struct virtqueue *vq) static struct virtio_transport virtio_transport = { .transport = { + .module = THIS_MODULE, + .get_local_cid = virtio_transport_get_local_cid, .init = virtio_transport_do_socket_init, diff --git a/net/vmw_vsock/vmci_transport.c b/net/vmw_vsock/vmci_transport.c index d9c9c834ad6f1c508c1ceff4f88a7221f53d0a72..644d32e43d230a84963e5ea2fa68dbddfdfe4aee 100644 --- a/net/vmw_vsock/vmci_transport.c +++ b/net/vmw_vsock/vmci_transport.c @@ -2020,6 +2020,7 @@ static u32 vmci_transport_get_local_cid(void) } static struct vsock_transport vmci_transport = { + .module = THIS_MODULE, .init = vmci_transport_socket_init, .destruct = vmci_transport_destruct, .release = vmci_transport_release,