diff --git a/drivers/net/ethernet/mellanox/mlxsw/spectrum.h b/drivers/net/ethernet/mellanox/mlxsw/spectrum.h
index 64cae57bfbc85df35a722e03f280b7f8a0121d10..92faae0064802352883f0d7680dd0f96add7861a 100644
--- a/drivers/net/ethernet/mellanox/mlxsw/spectrum.h
+++ b/drivers/net/ethernet/mellanox/mlxsw/spectrum.h
@@ -484,7 +484,6 @@ mlxsw_sp_netdevice_ipip_ul_event(struct mlxsw_sp *mlxsw_sp,
 				 struct netdev_notifier_info *info);
 void
 mlxsw_sp_port_vlan_router_leave(struct mlxsw_sp_port_vlan *mlxsw_sp_port_vlan);
-void mlxsw_sp_rif_destroy(struct mlxsw_sp_rif *rif);
 void mlxsw_sp_rif_destroy_by_dev(struct mlxsw_sp *mlxsw_sp,
 				 struct net_device *dev);
 struct mlxsw_sp_rif *mlxsw_sp_rif_find_by_dev(const struct mlxsw_sp *mlxsw_sp,
@@ -789,6 +788,7 @@ enum mlxsw_sp_rif_type mlxsw_sp_fid_rif_type(const struct mlxsw_sp_fid *fid);
 u16 mlxsw_sp_fid_index(const struct mlxsw_sp_fid *fid);
 enum mlxsw_sp_fid_type mlxsw_sp_fid_type(const struct mlxsw_sp_fid *fid);
 void mlxsw_sp_fid_rif_set(struct mlxsw_sp_fid *fid, struct mlxsw_sp_rif *rif);
+struct mlxsw_sp_rif *mlxsw_sp_fid_rif(const struct mlxsw_sp_fid *fid);
 enum mlxsw_sp_rif_type
 mlxsw_sp_fid_type_rif_type(const struct mlxsw_sp *mlxsw_sp,
 			   enum mlxsw_sp_fid_type type);
diff --git a/drivers/net/ethernet/mellanox/mlxsw/spectrum_fid.c b/drivers/net/ethernet/mellanox/mlxsw/spectrum_fid.c
index 7adb1494ebbaed9f017d1d6a4ba665af09f234cd..ee3b8db5d9e03de01a48023726a201ac3b478574 100644
--- a/drivers/net/ethernet/mellanox/mlxsw/spectrum_fid.c
+++ b/drivers/net/ethernet/mellanox/mlxsw/spectrum_fid.c
@@ -369,6 +369,11 @@ void mlxsw_sp_fid_rif_set(struct mlxsw_sp_fid *fid, struct mlxsw_sp_rif *rif)
 	fid->rif = rif;
 }
 
+struct mlxsw_sp_rif *mlxsw_sp_fid_rif(const struct mlxsw_sp_fid *fid)
+{
+	return fid->rif;
+}
+
 enum mlxsw_sp_rif_type
 mlxsw_sp_fid_type_rif_type(const struct mlxsw_sp *mlxsw_sp,
 			   enum mlxsw_sp_fid_type type)
@@ -1083,20 +1088,16 @@ void mlxsw_sp_fid_put(struct mlxsw_sp_fid *fid)
 	struct mlxsw_sp_fid_family *fid_family = fid->fid_family;
 	struct mlxsw_sp *mlxsw_sp = fid_family->mlxsw_sp;
 
-	if (--fid->ref_count == 1 && fid->rif) {
-		/* Destroy the associated RIF and let it drop the last
-		 * reference on the FID.
-		 */
-		return mlxsw_sp_rif_destroy(fid->rif);
-	} else if (fid->ref_count == 0) {
-		list_del(&fid->list);
-		rhashtable_remove_fast(&mlxsw_sp->fid_core->fid_ht,
-				       &fid->ht_node, mlxsw_sp_fid_ht_params);
-		fid->fid_family->ops->deconfigure(fid);
-		__clear_bit(fid->fid_index - fid_family->start_index,
-			    fid_family->fids_bitmap);
-		kfree(fid);
-	}
+	if (--fid->ref_count != 0)
+		return;
+
+	list_del(&fid->list);
+	rhashtable_remove_fast(&mlxsw_sp->fid_core->fid_ht,
+			       &fid->ht_node, mlxsw_sp_fid_ht_params);
+	fid->fid_family->ops->deconfigure(fid);
+	__clear_bit(fid->fid_index - fid_family->start_index,
+		    fid_family->fids_bitmap);
+	kfree(fid);
 }
 
 struct mlxsw_sp_fid *mlxsw_sp_fid_8021q_get(struct mlxsw_sp *mlxsw_sp, u16 vid)
diff --git a/drivers/net/ethernet/mellanox/mlxsw/spectrum_router.c b/drivers/net/ethernet/mellanox/mlxsw/spectrum_router.c
index b33ba8b6686163989d720b6f8eb04c6edb64e1b3..8a0011a610d62f9e4b1d38b175979b3fe7c17a60 100644
--- a/drivers/net/ethernet/mellanox/mlxsw/spectrum_router.c
+++ b/drivers/net/ethernet/mellanox/mlxsw/spectrum_router.c
@@ -15,6 +15,7 @@
 #include <linux/gcd.h>
 #include <linux/random.h>
 #include <linux/if_macvlan.h>
+#include <linux/refcount.h>
 #include <net/netevent.h>
 #include <net/neighbour.h>
 #include <net/arp.h>
@@ -104,6 +105,7 @@ struct mlxsw_sp_rif_params {
 
 struct mlxsw_sp_rif_subport {
 	struct mlxsw_sp_rif common;
+	refcount_t ref_count;
 	union {
 		u16 system_port;
 		u16 lag_id;
@@ -136,6 +138,7 @@ struct mlxsw_sp_rif_ops {
 	void (*fdb_del)(struct mlxsw_sp_rif *rif, const char *mac);
 };
 
+static void mlxsw_sp_rif_destroy(struct mlxsw_sp_rif *rif);
 static void mlxsw_sp_lpm_tree_hold(struct mlxsw_sp_lpm_tree *lpm_tree);
 static void mlxsw_sp_lpm_tree_put(struct mlxsw_sp *mlxsw_sp,
 				  struct mlxsw_sp_lpm_tree *lpm_tree);
@@ -6343,7 +6346,7 @@ mlxsw_sp_rif_create(struct mlxsw_sp *mlxsw_sp,
 	return ERR_PTR(err);
 }
 
-void mlxsw_sp_rif_destroy(struct mlxsw_sp_rif *rif)
+static void mlxsw_sp_rif_destroy(struct mlxsw_sp_rif *rif)
 {
 	const struct mlxsw_sp_rif_ops *ops = rif->ops;
 	struct mlxsw_sp *mlxsw_sp = rif->mlxsw_sp;
@@ -6392,6 +6395,40 @@ mlxsw_sp_rif_subport_params_init(struct mlxsw_sp_rif_params *params,
 		params->system_port = mlxsw_sp_port->local_port;
 }
 
+static struct mlxsw_sp_rif_subport *
+mlxsw_sp_rif_subport_rif(const struct mlxsw_sp_rif *rif)
+{
+	return container_of(rif, struct mlxsw_sp_rif_subport, common);
+}
+
+static struct mlxsw_sp_rif *
+mlxsw_sp_rif_subport_get(struct mlxsw_sp *mlxsw_sp,
+			 const struct mlxsw_sp_rif_params *params,
+			 struct netlink_ext_ack *extack)
+{
+	struct mlxsw_sp_rif_subport *rif_subport;
+	struct mlxsw_sp_rif *rif;
+
+	rif = mlxsw_sp_rif_find_by_dev(mlxsw_sp, params->dev);
+	if (!rif)
+		return mlxsw_sp_rif_create(mlxsw_sp, params, extack);
+
+	rif_subport = mlxsw_sp_rif_subport_rif(rif);
+	refcount_inc(&rif_subport->ref_count);
+	return rif;
+}
+
+static void mlxsw_sp_rif_subport_put(struct mlxsw_sp_rif *rif)
+{
+	struct mlxsw_sp_rif_subport *rif_subport;
+
+	rif_subport = mlxsw_sp_rif_subport_rif(rif);
+	if (!refcount_dec_and_test(&rif_subport->ref_count))
+		return;
+
+	mlxsw_sp_rif_destroy(rif);
+}
+
 static int
 mlxsw_sp_port_vlan_router_join(struct mlxsw_sp_port_vlan *mlxsw_sp_port_vlan,
 			       struct net_device *l3_dev,
@@ -6399,22 +6436,18 @@ mlxsw_sp_port_vlan_router_join(struct mlxsw_sp_port_vlan *mlxsw_sp_port_vlan,
 {
 	struct mlxsw_sp_port *mlxsw_sp_port = mlxsw_sp_port_vlan->mlxsw_sp_port;
 	struct mlxsw_sp *mlxsw_sp = mlxsw_sp_port->mlxsw_sp;
+	struct mlxsw_sp_rif_params params = {
+		.dev = l3_dev,
+	};
 	u16 vid = mlxsw_sp_port_vlan->vid;
 	struct mlxsw_sp_rif *rif;
 	struct mlxsw_sp_fid *fid;
 	int err;
 
-	rif = mlxsw_sp_rif_find_by_dev(mlxsw_sp, l3_dev);
-	if (!rif) {
-		struct mlxsw_sp_rif_params params = {
-			.dev = l3_dev,
-		};
-
-		mlxsw_sp_rif_subport_params_init(&params, mlxsw_sp_port_vlan);
-		rif = mlxsw_sp_rif_create(mlxsw_sp, &params, extack);
-		if (IS_ERR(rif))
-			return PTR_ERR(rif);
-	}
+	mlxsw_sp_rif_subport_params_init(&params, mlxsw_sp_port_vlan);
+	rif = mlxsw_sp_rif_subport_get(mlxsw_sp, &params, extack);
+	if (IS_ERR(rif))
+		return PTR_ERR(rif);
 
 	/* FID was already created, just take a reference */
 	fid = rif->ops->fid_get(rif, extack);
@@ -6441,6 +6474,7 @@ mlxsw_sp_port_vlan_router_join(struct mlxsw_sp_port_vlan *mlxsw_sp_port_vlan,
 	mlxsw_sp_fid_port_vid_unmap(fid, mlxsw_sp_port, vid);
 err_fid_port_vid_map:
 	mlxsw_sp_fid_put(fid);
+	mlxsw_sp_rif_subport_put(rif);
 	return err;
 }
 
@@ -6449,6 +6483,7 @@ mlxsw_sp_port_vlan_router_leave(struct mlxsw_sp_port_vlan *mlxsw_sp_port_vlan)
 {
 	struct mlxsw_sp_port *mlxsw_sp_port = mlxsw_sp_port_vlan->mlxsw_sp_port;
 	struct mlxsw_sp_fid *fid = mlxsw_sp_port_vlan->fid;
+	struct mlxsw_sp_rif *rif = mlxsw_sp_fid_rif(fid);
 	u16 vid = mlxsw_sp_port_vlan->vid;
 
 	if (WARN_ON(mlxsw_sp_fid_type(fid) != MLXSW_SP_FID_TYPE_RFID))
@@ -6458,10 +6493,8 @@ mlxsw_sp_port_vlan_router_leave(struct mlxsw_sp_port_vlan *mlxsw_sp_port_vlan)
 	mlxsw_sp_port_vid_stp_set(mlxsw_sp_port, vid, BR_STATE_BLOCKING);
 	mlxsw_sp_port_vid_learning_set(mlxsw_sp_port, vid, true);
 	mlxsw_sp_fid_port_vid_unmap(fid, mlxsw_sp_port, vid);
-	/* If router port holds the last reference on the rFID, then the
-	 * associated Sub-port RIF will be destroyed.
-	 */
 	mlxsw_sp_fid_put(fid);
+	mlxsw_sp_rif_subport_put(rif);
 }
 
 static int mlxsw_sp_inetaddr_port_vlan_event(struct net_device *l3_dev,
@@ -7064,18 +7097,13 @@ static int mlxsw_sp_rif_macvlan_flush(struct mlxsw_sp_rif *rif)
 					     __mlxsw_sp_rif_macvlan_flush, rif);
 }
 
-static struct mlxsw_sp_rif_subport *
-mlxsw_sp_rif_subport_rif(const struct mlxsw_sp_rif *rif)
-{
-	return container_of(rif, struct mlxsw_sp_rif_subport, common);
-}
-
 static void mlxsw_sp_rif_subport_setup(struct mlxsw_sp_rif *rif,
 				       const struct mlxsw_sp_rif_params *params)
 {
 	struct mlxsw_sp_rif_subport *rif_subport;
 
 	rif_subport = mlxsw_sp_rif_subport_rif(rif);
+	refcount_set(&rif_subport->ref_count, 1);
 	rif_subport->vid = params->vid;
 	rif_subport->lag = params->lag;
 	if (params->lag)