From 76039ac9095f5ee5ec7fb95ccb6a5460d5f8c3a2 Mon Sep 17 00:00:00 2001
From: Mark Zhang <markzhang@nvidia.com>
Date: Wed, 2 Jun 2021 13:27:08 +0300
Subject: [PATCH] IB/cm: Protect cm_dev, cm_ports and mad_agent with kref and
 lock

During cm_dev deregistration in cm_remove_one(), the cm_device and
cm_ports will be freed, after that they should not be accessed. The
mad_agent needs to be protected as well.

This patch adds a cm_device kref to protect cm_dev and cm_ports, and a
mad_agent_lock spinlock to protect mad_agent.

Link: https://lore.kernel.org/r/501ba7a2ff203dccd0e6755d3f93329772adce52.1622629024.git.leonro@nvidia.com
Signed-off-by: Mark Zhang <markzhang@nvidia.com>
Signed-off-by: Leon Romanovsky <leonro@nvidia.com>
Signed-off-by: Jason Gunthorpe <jgg@nvidia.com>
---
 drivers/infiniband/core/cm.c | 112 +++++++++++++++++++++++++++++------
 1 file changed, 93 insertions(+), 19 deletions(-)

diff --git a/drivers/infiniband/core/cm.c b/drivers/infiniband/core/cm.c
index 2e4658795e8e..80087e678030 100644
--- a/drivers/infiniband/core/cm.c
+++ b/drivers/infiniband/core/cm.c
@@ -205,7 +205,9 @@ struct cm_port {
 };
 
 struct cm_device {
+	struct kref kref;
 	struct list_head list;
+	spinlock_t mad_agent_lock;
 	struct ib_device *ib_device;
 	u8 ack_delay;
 	int going_down;
@@ -287,6 +289,22 @@ struct cm_id_private {
 	struct rdma_ucm_ece ece;
 };
 
+static void cm_dev_release(struct kref *kref)
+{
+	struct cm_device *cm_dev = container_of(kref, struct cm_device, kref);
+	u32 i;
+
+	rdma_for_each_port(cm_dev->ib_device, i)
+		kfree(cm_dev->port[i - 1]);
+
+	kfree(cm_dev);
+}
+
+static void cm_device_put(struct cm_device *cm_dev)
+{
+	kref_put(&cm_dev->kref, cm_dev_release);
+}
+
 static void cm_work_handler(struct work_struct *work);
 
 static inline void cm_deref_id(struct cm_id_private *cm_id_priv)
@@ -301,10 +319,23 @@ static struct ib_mad_send_buf *cm_alloc_msg(struct cm_id_private *cm_id_priv)
 	struct ib_mad_send_buf *m;
 	struct ib_ah *ah;
 
+	lockdep_assert_held(&cm_id_priv->lock);
+
+	if (!cm_id_priv->av.port)
+		return ERR_PTR(-EINVAL);
+
+	spin_lock(&cm_id_priv->av.port->cm_dev->mad_agent_lock);
 	mad_agent = cm_id_priv->av.port->mad_agent;
+	if (!mad_agent) {
+		m = ERR_PTR(-EINVAL);
+		goto out;
+	}
+
 	ah = rdma_create_ah(mad_agent->qp->pd, &cm_id_priv->av.ah_attr, 0);
-	if (IS_ERR(ah))
-		return (void *)ah;
+	if (IS_ERR(ah)) {
+		m = ERR_CAST(ah);
+		goto out;
+	}
 
 	m = ib_create_send_mad(mad_agent, cm_id_priv->id.remote_cm_qpn,
 			       cm_id_priv->av.pkey_index,
@@ -313,7 +344,7 @@ static struct ib_mad_send_buf *cm_alloc_msg(struct cm_id_private *cm_id_priv)
 			       IB_MGMT_BASE_VERSION);
 	if (IS_ERR(m)) {
 		rdma_destroy_ah(ah, 0);
-		return m;
+		goto out;
 	}
 
 	/* Timeout set by caller if response is expected. */
@@ -322,6 +353,9 @@ static struct ib_mad_send_buf *cm_alloc_msg(struct cm_id_private *cm_id_priv)
 
 	refcount_inc(&cm_id_priv->refcount);
 	m->context[0] = cm_id_priv;
+
+out:
+	spin_unlock(&cm_id_priv->av.port->cm_dev->mad_agent_lock);
 	return m;
 }
 
@@ -440,10 +474,24 @@ static void cm_set_private_data(struct cm_id_private *cm_id_priv,
 	cm_id_priv->private_data_len = private_data_len;
 }
 
+static void cm_set_av_port(struct cm_av *av, struct cm_port *port)
+{
+	struct cm_port *old_port = av->port;
+
+	if (old_port == port)
+		return;
+
+	av->port = port;
+	if (old_port)
+		cm_device_put(old_port->cm_dev);
+	if (port)
+		kref_get(&port->cm_dev->kref);
+}
+
 static void cm_init_av_for_lap(struct cm_port *port, struct ib_wc *wc,
 			       struct rdma_ah_attr *ah_attr, struct cm_av *av)
 {
-	av->port = port;
+	cm_set_av_port(av, port);
 	av->pkey_index = wc->pkey_index;
 	rdma_move_ah_attr(&av->ah_attr, ah_attr);
 }
@@ -451,7 +499,7 @@ static void cm_init_av_for_lap(struct cm_port *port, struct ib_wc *wc,
 static int cm_init_av_for_response(struct cm_port *port, struct ib_wc *wc,
 				   struct ib_grh *grh, struct cm_av *av)
 {
-	av->port = port;
+	cm_set_av_port(av, port);
 	av->pkey_index = wc->pkey_index;
 	return ib_init_ah_attr_from_wc(port->cm_dev->ib_device,
 				       port->port_num, wc,
@@ -518,7 +566,7 @@ static int cm_init_av_by_path(struct sa_path_rec *path,
 	if (ret)
 		return ret;
 
-	av->port = port;
+	cm_set_av_port(av, port);
 
 	/*
 	 * av->ah_attr might be initialized based on wc or during
@@ -542,7 +590,8 @@ static int cm_init_av_by_path(struct sa_path_rec *path,
 /* Move av created by cm_init_av_by_path(), so av.dgid is not moved */
 static void cm_move_av_from_path(struct cm_av *dest, struct cm_av *src)
 {
-	dest->port = src->port;
+	cm_set_av_port(dest, src->port);
+	cm_set_av_port(src, NULL);
 	dest->pkey_index = src->pkey_index;
 	rdma_move_ah_attr(&dest->ah_attr, &src->ah_attr);
 	dest->timeout = src->timeout;
@@ -551,6 +600,7 @@ static void cm_move_av_from_path(struct cm_av *dest, struct cm_av *src)
 static void cm_destroy_av(struct cm_av *av)
 {
 	rdma_destroy_ah_attr(&av->ah_attr);
+	cm_set_av_port(av, NULL);
 }
 
 static u32 cm_local_id(__be32 local_id)
@@ -1275,10 +1325,18 @@ EXPORT_SYMBOL(ib_cm_insert_listen);
 
 static __be64 cm_form_tid(struct cm_id_private *cm_id_priv)
 {
-	u64 hi_tid, low_tid;
+	u64 hi_tid = 0, low_tid;
+
+	lockdep_assert_held(&cm_id_priv->lock);
+
+	low_tid = (u64)cm_id_priv->id.local_id;
+	if (!cm_id_priv->av.port)
+		return cpu_to_be64(low_tid);
 
-	hi_tid   = ((u64) cm_id_priv->av.port->mad_agent->hi_tid) << 32;
-	low_tid  = (u64)cm_id_priv->id.local_id;
+	spin_lock(&cm_id_priv->av.port->cm_dev->mad_agent_lock);
+	if (cm_id_priv->av.port->mad_agent)
+		hi_tid = ((u64)cm_id_priv->av.port->mad_agent->hi_tid) << 32;
+	spin_unlock(&cm_id_priv->av.port->cm_dev->mad_agent_lock);
 	return cpu_to_be64(hi_tid | low_tid);
 }
 
@@ -2139,6 +2197,9 @@ static int cm_req_handler(struct cm_work *work)
 		sa_path_set_dmac(&work->path[0],
 				 cm_id_priv->av.ah_attr.roce.dmac);
 	work->path[0].hop_limit = grh->hop_limit;
+
+	/* This destroy call is needed to pair with cm_init_av_for_response */
+	cm_destroy_av(&cm_id_priv->av);
 	ret = cm_init_av_by_path(&work->path[0], gid_attr, &cm_id_priv->av);
 	if (ret) {
 		int err;
@@ -4090,7 +4151,8 @@ static int cm_init_qp_init_attr(struct cm_id_private *cm_id_priv,
 			qp_attr->qp_access_flags |= IB_ACCESS_REMOTE_READ |
 						    IB_ACCESS_REMOTE_ATOMIC;
 		qp_attr->pkey_index = cm_id_priv->av.pkey_index;
-		qp_attr->port_num = cm_id_priv->av.port->port_num;
+		if (cm_id_priv->av.port)
+			qp_attr->port_num = cm_id_priv->av.port->port_num;
 		ret = 0;
 		break;
 	default:
@@ -4132,7 +4194,8 @@ static int cm_init_qp_rtr_attr(struct cm_id_private *cm_id_priv,
 					cm_id_priv->responder_resources;
 			qp_attr->min_rnr_timer = 0;
 		}
-		if (rdma_ah_get_dlid(&cm_id_priv->alt_av.ah_attr)) {
+		if (rdma_ah_get_dlid(&cm_id_priv->alt_av.ah_attr) &&
+		    cm_id_priv->alt_av.port) {
 			*qp_attr_mask |= IB_QP_ALT_PATH;
 			qp_attr->alt_port_num = cm_id_priv->alt_av.port->port_num;
 			qp_attr->alt_pkey_index = cm_id_priv->alt_av.pkey_index;
@@ -4193,7 +4256,9 @@ static int cm_init_qp_rts_attr(struct cm_id_private *cm_id_priv,
 			}
 		} else {
 			*qp_attr_mask = IB_QP_ALT_PATH | IB_QP_PATH_MIG_STATE;
-			qp_attr->alt_port_num = cm_id_priv->alt_av.port->port_num;
+			if (cm_id_priv->alt_av.port)
+				qp_attr->alt_port_num =
+					cm_id_priv->alt_av.port->port_num;
 			qp_attr->alt_pkey_index = cm_id_priv->alt_av.pkey_index;
 			qp_attr->alt_timeout = cm_id_priv->alt_av.timeout;
 			qp_attr->alt_ah_attr = cm_id_priv->alt_av.ah_attr;
@@ -4311,6 +4376,8 @@ static int cm_add_one(struct ib_device *ib_device)
 	if (!cm_dev)
 		return -ENOMEM;
 
+	kref_init(&cm_dev->kref);
+	spin_lock_init(&cm_dev->mad_agent_lock);
 	cm_dev->ib_device = ib_device;
 	cm_dev->ack_delay = ib_device->attrs.local_ca_ack_delay;
 	cm_dev->going_down = 0;
@@ -4373,7 +4440,6 @@ static int cm_add_one(struct ib_device *ib_device)
 error1:
 	port_modify.set_port_cap_mask = 0;
 	port_modify.clr_port_cap_mask = IB_PORT_CM_SUP;
-	kfree(port);
 	while (--i) {
 		if (!rdma_cap_ib_cm(ib_device, i))
 			continue;
@@ -4382,10 +4448,9 @@ static int cm_add_one(struct ib_device *ib_device)
 		ib_modify_port(ib_device, port->port_num, 0, &port_modify);
 		ib_unregister_mad_agent(port->mad_agent);
 		cm_remove_port_fs(port);
-		kfree(port);
 	}
 free:
-	kfree(cm_dev);
+	cm_device_put(cm_dev);
 	return ret;
 }
 
@@ -4408,10 +4473,13 @@ static void cm_remove_one(struct ib_device *ib_device, void *client_data)
 	spin_unlock_irq(&cm.lock);
 
 	rdma_for_each_port (ib_device, i) {
+		struct ib_mad_agent *mad_agent;
+
 		if (!rdma_cap_ib_cm(ib_device, i))
 			continue;
 
 		port = cm_dev->port[i-1];
+		mad_agent = port->mad_agent;
 		ib_modify_port(ib_device, port->port_num, 0, &port_modify);
 		/*
 		 * We flush the queue here after the going_down set, this
@@ -4419,12 +4487,18 @@ static void cm_remove_one(struct ib_device *ib_device, void *client_data)
 		 * after that we can call the unregister_mad_agent
 		 */
 		flush_workqueue(cm.wq);
-		ib_unregister_mad_agent(port->mad_agent);
+		/*
+		 * The above ensures no call paths from the work are running,
+		 * the remaining paths all take the mad_agent_lock.
+		 */
+		spin_lock(&cm_dev->mad_agent_lock);
+		port->mad_agent = NULL;
+		spin_unlock(&cm_dev->mad_agent_lock);
+		ib_unregister_mad_agent(mad_agent);
 		cm_remove_port_fs(port);
-		kfree(port);
 	}
 
-	kfree(cm_dev);
+	cm_device_put(cm_dev);
 }
 
 static int __init ib_cm_init(void)
-- 
GitLab