diff --git a/net/mptcp/protocol.c b/net/mptcp/protocol.c
index 2ff8c7caf74fce9c0c7c37ff88eed3519362e40d..81faeff8f3bb7e00991af5ee30179252de103de1 100644
--- a/net/mptcp/protocol.c
+++ b/net/mptcp/protocol.c
@@ -2642,11 +2642,12 @@ static void mptcp_copy_inaddrs(struct sock *msk, const struct sock *ssk)
 
 static int mptcp_disconnect(struct sock *sk, int flags)
 {
-	/* Should never be called.
-	 * inet_stream_connect() calls ->disconnect, but that
-	 * refers to the subflow socket, not the mptcp one.
-	 */
-	WARN_ON_ONCE(1);
+	struct mptcp_subflow_context *subflow;
+	struct mptcp_sock *msk = mptcp_sk(sk);
+
+	__mptcp_flush_join_list(msk);
+	mptcp_for_each_subflow(msk, subflow)
+		tcp_disconnect(mptcp_subflow_tcp_sock(subflow), flags);
 	return 0;
 }
 
@@ -3089,6 +3090,14 @@ bool mptcp_finish_join(struct sock *ssk)
 	return true;
 }
 
+static void mptcp_shutdown(struct sock *sk, int how)
+{
+	pr_debug("sk=%p, how=%d", sk, how);
+
+	if ((how & SEND_SHUTDOWN) && mptcp_close_state(sk))
+		__mptcp_wr_shutdown(sk);
+}
+
 static struct proto mptcp_prot = {
 	.name		= "MPTCP",
 	.owner		= THIS_MODULE,
@@ -3098,7 +3107,7 @@ static struct proto mptcp_prot = {
 	.accept		= mptcp_accept,
 	.setsockopt	= mptcp_setsockopt,
 	.getsockopt	= mptcp_getsockopt,
-	.shutdown	= tcp_shutdown,
+	.shutdown	= mptcp_shutdown,
 	.destroy	= mptcp_destroy,
 	.sendmsg	= mptcp_sendmsg,
 	.recvmsg	= mptcp_recvmsg,
@@ -3344,43 +3353,6 @@ static __poll_t mptcp_poll(struct file *file, struct socket *sock,
 	return mask;
 }
 
-static int mptcp_shutdown(struct socket *sock, int how)
-{
-	struct mptcp_sock *msk = mptcp_sk(sock->sk);
-	struct sock *sk = sock->sk;
-	int ret = 0;
-
-	pr_debug("sk=%p, how=%d", msk, how);
-
-	lock_sock(sk);
-
-	how++;
-	if ((how & ~SHUTDOWN_MASK) || !how) {
-		ret = -EINVAL;
-		goto out_unlock;
-	}
-
-	if (sock->state == SS_CONNECTING) {
-		if ((1 << sk->sk_state) &
-		    (TCPF_SYN_SENT | TCPF_SYN_RECV | TCPF_CLOSE))
-			sock->state = SS_DISCONNECTING;
-		else
-			sock->state = SS_CONNECTED;
-	}
-
-	sk->sk_shutdown |= how;
-	if ((how & SEND_SHUTDOWN) && mptcp_close_state(sk))
-		__mptcp_wr_shutdown(sk);
-
-	/* Wake up anyone sleeping in poll. */
-	sk->sk_state_change(sk);
-
-out_unlock:
-	release_sock(sk);
-
-	return ret;
-}
-
 static const struct proto_ops mptcp_stream_ops = {
 	.family		   = PF_INET,
 	.owner		   = THIS_MODULE,
@@ -3394,7 +3366,7 @@ static const struct proto_ops mptcp_stream_ops = {
 	.ioctl		   = inet_ioctl,
 	.gettstamp	   = sock_gettstamp,
 	.listen		   = mptcp_listen,
-	.shutdown	   = mptcp_shutdown,
+	.shutdown	   = inet_shutdown,
 	.setsockopt	   = sock_common_setsockopt,
 	.getsockopt	   = sock_common_getsockopt,
 	.sendmsg	   = inet_sendmsg,
@@ -3444,7 +3416,7 @@ static const struct proto_ops mptcp_v6_stream_ops = {
 	.ioctl		   = inet6_ioctl,
 	.gettstamp	   = sock_gettstamp,
 	.listen		   = mptcp_listen,
-	.shutdown	   = mptcp_shutdown,
+	.shutdown	   = inet_shutdown,
 	.setsockopt	   = sock_common_setsockopt,
 	.getsockopt	   = sock_common_getsockopt,
 	.sendmsg	   = inet6_sendmsg,