diff --git a/net/vmw_vsock/af_vsock.c b/net/vmw_vsock/af_vsock.c index bbe65dcb973834761c30cef249a537ba498898fb..3dce53ebea9240fd2f46f9c428170a2e6c345d7b 100644 --- a/net/vmw_vsock/af_vsock.c +++ b/net/vmw_vsock/af_vsock.c @@ -1209,10 +1209,14 @@ static int vsock_stream_connect(struct socket *sock, struct sockaddr *addr, if (signal_pending(current)) { err = sock_intr_errno(timeout); - goto out_wait_error; + sk->sk_state = SS_UNCONNECTED; + sock->state = SS_UNCONNECTED; + goto out_wait; } else if (timeout == 0) { err = -ETIMEDOUT; - goto out_wait_error; + sk->sk_state = SS_UNCONNECTED; + sock->state = SS_UNCONNECTED; + goto out_wait; } prepare_to_wait(sk_sleep(sk), &wait, TASK_INTERRUPTIBLE); @@ -1220,20 +1224,17 @@ static int vsock_stream_connect(struct socket *sock, struct sockaddr *addr, if (sk->sk_err) { err = -sk->sk_err; - goto out_wait_error; - } else + sk->sk_state = SS_UNCONNECTED; + sock->state = SS_UNCONNECTED; + } else { err = 0; + } out_wait: finish_wait(sk_sleep(sk), &wait); out: release_sock(sk); return err; - -out_wait_error: - sk->sk_state = SS_UNCONNECTED; - sock->state = SS_UNCONNECTED; - goto out_wait; } static int vsock_accept(struct socket *sock, struct socket *newsock, int flags) @@ -1270,18 +1271,20 @@ static int vsock_accept(struct socket *sock, struct socket *newsock, int flags) listener->sk_err == 0) { release_sock(listener); timeout = schedule_timeout(timeout); + finish_wait(sk_sleep(listener), &wait); lock_sock(listener); if (signal_pending(current)) { err = sock_intr_errno(timeout); - goto out_wait; + goto out; } else if (timeout == 0) { err = -EAGAIN; - goto out_wait; + goto out; } prepare_to_wait(sk_sleep(listener), &wait, TASK_INTERRUPTIBLE); } + finish_wait(sk_sleep(listener), &wait); if (listener->sk_err) err = -listener->sk_err; @@ -1301,19 +1304,15 @@ static int vsock_accept(struct socket *sock, struct socket *newsock, int flags) */ if (err) { vconnected->rejected = true; - release_sock(connected); - sock_put(connected); - goto out_wait; + } else { + newsock->state = SS_CONNECTED; + sock_graft(connected, newsock); } - newsock->state = SS_CONNECTED; - sock_graft(connected, newsock); release_sock(connected); sock_put(connected); } -out_wait: - finish_wait(sk_sleep(listener), &wait); out: release_sock(listener); return err; @@ -1557,9 +1556,11 @@ static int vsock_stream_sendmsg(struct socket *sock, struct msghdr *msg, if (err < 0) goto out; + while (total_written < len) { ssize_t written; + prepare_to_wait(sk_sleep(sk), &wait, TASK_INTERRUPTIBLE); while (vsock_stream_has_space(vsk) == 0 && sk->sk_err == 0 && !(sk->sk_shutdown & SEND_SHUTDOWN) && @@ -1568,27 +1569,33 @@ static int vsock_stream_sendmsg(struct socket *sock, struct msghdr *msg, /* Don't wait for non-blocking sockets. */ if (timeout == 0) { err = -EAGAIN; - goto out_wait; + finish_wait(sk_sleep(sk), &wait); + goto out_err; } err = transport->notify_send_pre_block(vsk, &send_data); - if (err < 0) - goto out_wait; + if (err < 0) { + finish_wait(sk_sleep(sk), &wait); + goto out_err; + } release_sock(sk); - prepare_to_wait(sk_sleep(sk), &wait, TASK_INTERRUPTIBLE); timeout = schedule_timeout(timeout); - finish_wait(sk_sleep(sk), &wait); lock_sock(sk); if (signal_pending(current)) { err = sock_intr_errno(timeout); - goto out_wait; + finish_wait(sk_sleep(sk), &wait); + goto out_err; } else if (timeout == 0) { err = -EAGAIN; - goto out_wait; + finish_wait(sk_sleep(sk), &wait); + goto out_err; } + prepare_to_wait(sk_sleep(sk), &wait, + TASK_INTERRUPTIBLE); } + finish_wait(sk_sleep(sk), &wait); /* These checks occur both as part of and after the loop * conditional since we need to check before and after @@ -1596,16 +1603,16 @@ static int vsock_stream_sendmsg(struct socket *sock, struct msghdr *msg, */ if (sk->sk_err) { err = -sk->sk_err; - goto out_wait; + goto out_err; } else if ((sk->sk_shutdown & SEND_SHUTDOWN) || (vsk->peer_shutdown & RCV_SHUTDOWN)) { err = -EPIPE; - goto out_wait; + goto out_err; } err = transport->notify_send_pre_enqueue(vsk, &send_data); if (err < 0) - goto out_wait; + goto out_err; /* Note that enqueue will only write as many bytes as are free * in the produce queue, so we don't need to ensure len is @@ -1618,7 +1625,7 @@ static int vsock_stream_sendmsg(struct socket *sock, struct msghdr *msg, len - total_written); if (written < 0) { err = -ENOMEM; - goto out_wait; + goto out_err; } total_written += written; @@ -1626,11 +1633,11 @@ static int vsock_stream_sendmsg(struct socket *sock, struct msghdr *msg, err = transport->notify_send_post_enqueue( vsk, written, &send_data); if (err < 0) - goto out_wait; + goto out_err; } -out_wait: +out_err: if (total_written > 0) err = total_written; out: @@ -1715,18 +1722,59 @@ vsock_stream_recvmsg(struct socket *sock, struct msghdr *msg, size_t len, while (1) { - s64 ready = vsock_stream_has_data(vsk); + s64 ready; - if (ready < 0) { - /* Invalid queue pair content. XXX This should be - * changed to a connection reset in a later change. - */ + prepare_to_wait(sk_sleep(sk), &wait, TASK_INTERRUPTIBLE); + ready = vsock_stream_has_data(vsk); - err = -ENOMEM; - goto out; - } else if (ready > 0) { + if (ready == 0) { + if (sk->sk_err != 0 || + (sk->sk_shutdown & RCV_SHUTDOWN) || + (vsk->peer_shutdown & SEND_SHUTDOWN)) { + finish_wait(sk_sleep(sk), &wait); + break; + } + /* Don't wait for non-blocking sockets. */ + if (timeout == 0) { + err = -EAGAIN; + finish_wait(sk_sleep(sk), &wait); + break; + } + + err = transport->notify_recv_pre_block( + vsk, target, &recv_data); + if (err < 0) { + finish_wait(sk_sleep(sk), &wait); + break; + } + release_sock(sk); + timeout = schedule_timeout(timeout); + lock_sock(sk); + + if (signal_pending(current)) { + err = sock_intr_errno(timeout); + finish_wait(sk_sleep(sk), &wait); + break; + } else if (timeout == 0) { + err = -EAGAIN; + finish_wait(sk_sleep(sk), &wait); + break; + } + } else { ssize_t read; + finish_wait(sk_sleep(sk), &wait); + + if (ready < 0) { + /* Invalid queue pair content. XXX This should + * be changed to a connection reset in a later + * change. + */ + + err = -ENOMEM; + goto out; + } + err = transport->notify_recv_pre_dequeue( vsk, target, &recv_data); if (err < 0) @@ -1752,35 +1800,6 @@ vsock_stream_recvmsg(struct socket *sock, struct msghdr *msg, size_t len, break; target -= read; - } else { - if (sk->sk_err != 0 || (sk->sk_shutdown & RCV_SHUTDOWN) - || (vsk->peer_shutdown & SEND_SHUTDOWN)) { - break; - } - /* Don't wait for non-blocking sockets. */ - if (timeout == 0) { - err = -EAGAIN; - break; - } - - err = transport->notify_recv_pre_block( - vsk, target, &recv_data); - if (err < 0) - break; - - release_sock(sk); - prepare_to_wait(sk_sleep(sk), &wait, TASK_INTERRUPTIBLE); - timeout = schedule_timeout(timeout); - finish_wait(sk_sleep(sk), &wait); - lock_sock(sk); - - if (signal_pending(current)) { - err = sock_intr_errno(timeout); - break; - } else if (timeout == 0) { - err = -EAGAIN; - break; - } } }