diff --git a/drivers/net/ethernet/mellanox/mlxsw/spectrum.c b/drivers/net/ethernet/mellanox/mlxsw/spectrum.c
index 3b5d0850278f0a9df6406f4f73272a6b0475dc2c..7777528f67fb682921a83cbe5eb04e992e871c4b 100644
--- a/drivers/net/ethernet/mellanox/mlxsw/spectrum.c
+++ b/drivers/net/ethernet/mellanox/mlxsw/spectrum.c
@@ -1141,16 +1141,20 @@ static void mlxsw_sp_port_vlan_flush(struct mlxsw_sp_port *mlxsw_sp_port)
 
 	list_for_each_entry_safe(mlxsw_sp_port_vlan, tmp,
 				 &mlxsw_sp_port->vlans_list, list)
-		mlxsw_sp_port_vlan_put(mlxsw_sp_port_vlan);
+		mlxsw_sp_port_vlan_destroy(mlxsw_sp_port_vlan);
 }
 
-static struct mlxsw_sp_port_vlan *
+struct mlxsw_sp_port_vlan *
 mlxsw_sp_port_vlan_create(struct mlxsw_sp_port *mlxsw_sp_port, u16 vid)
 {
 	struct mlxsw_sp_port_vlan *mlxsw_sp_port_vlan;
 	bool untagged = vid == 1;
 	int err;
 
+	mlxsw_sp_port_vlan = mlxsw_sp_port_vlan_find_by_vid(mlxsw_sp_port, vid);
+	if (mlxsw_sp_port_vlan)
+		return ERR_PTR(-EEXIST);
+
 	err = mlxsw_sp_port_vlan_set(mlxsw_sp_port, vid, vid, true, untagged);
 	if (err)
 		return ERR_PTR(err);
@@ -1162,7 +1166,6 @@ mlxsw_sp_port_vlan_create(struct mlxsw_sp_port *mlxsw_sp_port, u16 vid)
 	}
 
 	mlxsw_sp_port_vlan->mlxsw_sp_port = mlxsw_sp_port;
-	mlxsw_sp_port_vlan->ref_count = 1;
 	mlxsw_sp_port_vlan->vid = vid;
 	list_add(&mlxsw_sp_port_vlan->list, &mlxsw_sp_port->vlans_list);
 
@@ -1173,44 +1176,19 @@ mlxsw_sp_port_vlan_create(struct mlxsw_sp_port *mlxsw_sp_port, u16 vid)
 	return ERR_PTR(err);
 }
 
-static void
-mlxsw_sp_port_vlan_destroy(struct mlxsw_sp_port_vlan *mlxsw_sp_port_vlan)
+void mlxsw_sp_port_vlan_destroy(struct mlxsw_sp_port_vlan *mlxsw_sp_port_vlan)
 {
 	struct mlxsw_sp_port *mlxsw_sp_port = mlxsw_sp_port_vlan->mlxsw_sp_port;
 	u16 vid = mlxsw_sp_port_vlan->vid;
 
-	list_del(&mlxsw_sp_port_vlan->list);
-	kfree(mlxsw_sp_port_vlan);
-	mlxsw_sp_port_vlan_set(mlxsw_sp_port, vid, vid, false, false);
-}
-
-struct mlxsw_sp_port_vlan *
-mlxsw_sp_port_vlan_get(struct mlxsw_sp_port *mlxsw_sp_port, u16 vid)
-{
-	struct mlxsw_sp_port_vlan *mlxsw_sp_port_vlan;
-
-	mlxsw_sp_port_vlan = mlxsw_sp_port_vlan_find_by_vid(mlxsw_sp_port, vid);
-	if (mlxsw_sp_port_vlan) {
-		mlxsw_sp_port_vlan->ref_count++;
-		return mlxsw_sp_port_vlan;
-	}
-
-	return mlxsw_sp_port_vlan_create(mlxsw_sp_port, vid);
-}
-
-void mlxsw_sp_port_vlan_put(struct mlxsw_sp_port_vlan *mlxsw_sp_port_vlan)
-{
-	struct mlxsw_sp_fid *fid = mlxsw_sp_port_vlan->fid;
-
-	if (--mlxsw_sp_port_vlan->ref_count != 0)
-		return;
-
 	if (mlxsw_sp_port_vlan->bridge_port)
 		mlxsw_sp_port_vlan_bridge_leave(mlxsw_sp_port_vlan);
-	else if (fid)
+	else if (mlxsw_sp_port_vlan->fid)
 		mlxsw_sp_port_vlan_router_leave(mlxsw_sp_port_vlan);
 
-	mlxsw_sp_port_vlan_destroy(mlxsw_sp_port_vlan);
+	list_del(&mlxsw_sp_port_vlan->list);
+	kfree(mlxsw_sp_port_vlan);
+	mlxsw_sp_port_vlan_set(mlxsw_sp_port, vid, vid, false, false);
 }
 
 static int mlxsw_sp_port_add_vid(struct net_device *dev,
@@ -1224,7 +1202,7 @@ static int mlxsw_sp_port_add_vid(struct net_device *dev,
 	if (!vid)
 		return 0;
 
-	return PTR_ERR_OR_ZERO(mlxsw_sp_port_vlan_get(mlxsw_sp_port, vid));
+	return PTR_ERR_OR_ZERO(mlxsw_sp_port_vlan_create(mlxsw_sp_port, vid));
 }
 
 static int mlxsw_sp_port_kill_vid(struct net_device *dev,
@@ -1242,7 +1220,7 @@ static int mlxsw_sp_port_kill_vid(struct net_device *dev,
 	mlxsw_sp_port_vlan = mlxsw_sp_port_vlan_find_by_vid(mlxsw_sp_port, vid);
 	if (!mlxsw_sp_port_vlan)
 		return 0;
-	mlxsw_sp_port_vlan_put(mlxsw_sp_port_vlan);
+	mlxsw_sp_port_vlan_destroy(mlxsw_sp_port_vlan);
 
 	return 0;
 }
@@ -3198,12 +3176,12 @@ static int mlxsw_sp_port_create(struct mlxsw_sp *mlxsw_sp, u8 local_port,
 		goto err_port_nve_init;
 	}
 
-	mlxsw_sp_port_vlan = mlxsw_sp_port_vlan_get(mlxsw_sp_port, 1);
+	mlxsw_sp_port_vlan = mlxsw_sp_port_vlan_create(mlxsw_sp_port, 1);
 	if (IS_ERR(mlxsw_sp_port_vlan)) {
 		dev_err(mlxsw_sp->bus_info->dev, "Port %d: Failed to create VID 1\n",
 			mlxsw_sp_port->local_port);
 		err = PTR_ERR(mlxsw_sp_port_vlan);
-		goto err_port_vlan_get;
+		goto err_port_vlan_create;
 	}
 
 	mlxsw_sp_port_switchdev_init(mlxsw_sp_port);
@@ -3224,8 +3202,8 @@ static int mlxsw_sp_port_create(struct mlxsw_sp *mlxsw_sp, u8 local_port,
 err_register_netdev:
 	mlxsw_sp->ports[local_port] = NULL;
 	mlxsw_sp_port_switchdev_fini(mlxsw_sp_port);
-	mlxsw_sp_port_vlan_put(mlxsw_sp_port_vlan);
-err_port_vlan_get:
+	mlxsw_sp_port_vlan_destroy(mlxsw_sp_port_vlan);
+err_port_vlan_create:
 	mlxsw_sp_port_nve_fini(mlxsw_sp_port);
 err_port_nve_init:
 	mlxsw_sp_tc_qdisc_fini(mlxsw_sp_port);
@@ -4721,7 +4699,7 @@ static void mlxsw_sp_port_lag_leave(struct mlxsw_sp_port *mlxsw_sp_port,
 	mlxsw_sp_port->lagged = 0;
 	lag->ref_count--;
 
-	mlxsw_sp_port_vlan_get(mlxsw_sp_port, 1);
+	mlxsw_sp_port_vlan_create(mlxsw_sp_port, 1);
 	/* Make sure untagged frames are allowed to ingress */
 	mlxsw_sp_port_pvid_set(mlxsw_sp_port, 1);
 }
diff --git a/drivers/net/ethernet/mellanox/mlxsw/spectrum.h b/drivers/net/ethernet/mellanox/mlxsw/spectrum.h
index 9b7ba4c3d334b1b37fffb5cafa1e740970d58ad6..cb1d37df87ea587f76bfc56e1bd5b84bcd3ec086 100644
--- a/drivers/net/ethernet/mellanox/mlxsw/spectrum.h
+++ b/drivers/net/ethernet/mellanox/mlxsw/spectrum.h
@@ -190,7 +190,6 @@ struct mlxsw_sp_port_vlan {
 	struct list_head list;
 	struct mlxsw_sp_port *mlxsw_sp_port;
 	struct mlxsw_sp_fid *fid;
-	unsigned int ref_count;
 	u16 vid;
 	struct mlxsw_sp_bridge_port *bridge_port;
 	struct list_head bridge_vlan_node;
@@ -410,8 +409,8 @@ int mlxsw_sp_port_vid_learning_set(struct mlxsw_sp_port *mlxsw_sp_port, u16 vid,
 				   bool learn_enable);
 int mlxsw_sp_port_pvid_set(struct mlxsw_sp_port *mlxsw_sp_port, u16 vid);
 struct mlxsw_sp_port_vlan *
-mlxsw_sp_port_vlan_get(struct mlxsw_sp_port *mlxsw_sp_port, u16 vid);
-void mlxsw_sp_port_vlan_put(struct mlxsw_sp_port_vlan *mlxsw_sp_port_vlan);
+mlxsw_sp_port_vlan_create(struct mlxsw_sp_port *mlxsw_sp_port, u16 vid);
+void mlxsw_sp_port_vlan_destroy(struct mlxsw_sp_port_vlan *mlxsw_sp_port_vlan);
 int mlxsw_sp_port_vlan_set(struct mlxsw_sp_port *mlxsw_sp_port, u16 vid_begin,
 			   u16 vid_end, bool is_member, bool untagged);
 int mlxsw_sp_flow_counter_get(struct mlxsw_sp *mlxsw_sp,
diff --git a/drivers/net/ethernet/mellanox/mlxsw/spectrum_switchdev.c b/drivers/net/ethernet/mellanox/mlxsw/spectrum_switchdev.c
index 627efe7a6eefcb8ddbd1a8d31faaaa28b06fafdb..0b2f724d9a0c97b517f6695fdd0ba0be24884991 100644
--- a/drivers/net/ethernet/mellanox/mlxsw/spectrum_switchdev.c
+++ b/drivers/net/ethernet/mellanox/mlxsw/spectrum_switchdev.c
@@ -1021,10 +1021,8 @@ mlxsw_sp_port_vlan_bridge_join(struct mlxsw_sp_port_vlan *mlxsw_sp_port_vlan,
 	int err;
 
 	/* No need to continue if only VLAN flags were changed */
-	if (mlxsw_sp_port_vlan->bridge_port) {
-		mlxsw_sp_port_vlan_put(mlxsw_sp_port_vlan);
+	if (mlxsw_sp_port_vlan->bridge_port)
 		return 0;
-	}
 
 	err = mlxsw_sp_port_vlan_fid_join(mlxsw_sp_port_vlan, bridge_port,
 					  extack);
@@ -1105,16 +1103,32 @@ static int
 mlxsw_sp_bridge_port_vlan_add(struct mlxsw_sp_port *mlxsw_sp_port,
 			      struct mlxsw_sp_bridge_port *bridge_port,
 			      u16 vid, bool is_untagged, bool is_pvid,
-			      struct netlink_ext_ack *extack)
+			      struct netlink_ext_ack *extack,
+			      struct switchdev_trans *trans)
 {
 	u16 pvid = mlxsw_sp_port_pvid_determine(mlxsw_sp_port, vid, is_pvid);
 	struct mlxsw_sp_port_vlan *mlxsw_sp_port_vlan;
 	u16 old_pvid = mlxsw_sp_port->pvid;
 	int err;
 
-	mlxsw_sp_port_vlan = mlxsw_sp_port_vlan_get(mlxsw_sp_port, vid);
-	if (IS_ERR(mlxsw_sp_port_vlan))
-		return PTR_ERR(mlxsw_sp_port_vlan);
+	/* The only valid scenario in which a port-vlan already exists, is if
+	 * the VLAN flags were changed and the port-vlan is associated with the
+	 * correct bridge port
+	 */
+	mlxsw_sp_port_vlan = mlxsw_sp_port_vlan_find_by_vid(mlxsw_sp_port, vid);
+	if (mlxsw_sp_port_vlan &&
+	    mlxsw_sp_port_vlan->bridge_port != bridge_port)
+		return -EEXIST;
+
+	if (switchdev_trans_ph_prepare(trans))
+		return 0;
+
+	if (!mlxsw_sp_port_vlan) {
+		mlxsw_sp_port_vlan = mlxsw_sp_port_vlan_create(mlxsw_sp_port,
+							       vid);
+		if (IS_ERR(mlxsw_sp_port_vlan))
+			return PTR_ERR(mlxsw_sp_port_vlan);
+	}
 
 	err = mlxsw_sp_port_vlan_set(mlxsw_sp_port, vid, vid, true,
 				     is_untagged);
@@ -1137,7 +1151,7 @@ mlxsw_sp_bridge_port_vlan_add(struct mlxsw_sp_port *mlxsw_sp_port,
 err_port_pvid_set:
 	mlxsw_sp_port_vlan_set(mlxsw_sp_port, vid, vid, false, false);
 err_port_vlan_set:
-	mlxsw_sp_port_vlan_put(mlxsw_sp_port_vlan);
+	mlxsw_sp_port_vlan_destroy(mlxsw_sp_port_vlan);
 	return err;
 }
 
@@ -1199,9 +1213,6 @@ static int mlxsw_sp_port_vlans_add(struct mlxsw_sp_port *mlxsw_sp_port,
 		return err;
 	}
 
-	if (switchdev_trans_ph_prepare(trans))
-		return 0;
-
 	bridge_port = mlxsw_sp_bridge_port_find(mlxsw_sp->bridge, orig_dev);
 	if (WARN_ON(!bridge_port))
 		return -EINVAL;
@@ -1214,7 +1225,7 @@ static int mlxsw_sp_port_vlans_add(struct mlxsw_sp_port *mlxsw_sp_port,
 
 		err = mlxsw_sp_bridge_port_vlan_add(mlxsw_sp_port, bridge_port,
 						    vid, flag_untagged,
-						    flag_pvid, extack);
+						    flag_pvid, extack, trans);
 		if (err)
 			return err;
 	}
@@ -1832,7 +1843,7 @@ mlxsw_sp_bridge_port_vlan_del(struct mlxsw_sp_port *mlxsw_sp_port,
 	mlxsw_sp_port_vlan_bridge_leave(mlxsw_sp_port_vlan);
 	mlxsw_sp_port_pvid_set(mlxsw_sp_port, pvid);
 	mlxsw_sp_port_vlan_set(mlxsw_sp_port, vid, vid, false, false);
-	mlxsw_sp_port_vlan_put(mlxsw_sp_port_vlan);
+	mlxsw_sp_port_vlan_destroy(mlxsw_sp_port_vlan);
 }
 
 static int mlxsw_sp_port_vlans_del(struct mlxsw_sp_port *mlxsw_sp_port,
@@ -2000,7 +2011,7 @@ mlxsw_sp_bridge_8021q_port_join(struct mlxsw_sp_bridge_device *bridge_device,
 		return -EINVAL;
 
 	/* Let VLAN-aware bridge take care of its own VLANs */
-	mlxsw_sp_port_vlan_put(mlxsw_sp_port_vlan);
+	mlxsw_sp_port_vlan_destroy(mlxsw_sp_port_vlan);
 
 	return 0;
 }
@@ -2010,7 +2021,7 @@ mlxsw_sp_bridge_8021q_port_leave(struct mlxsw_sp_bridge_device *bridge_device,
 				 struct mlxsw_sp_bridge_port *bridge_port,
 				 struct mlxsw_sp_port *mlxsw_sp_port)
 {
-	mlxsw_sp_port_vlan_get(mlxsw_sp_port, 1);
+	mlxsw_sp_port_vlan_create(mlxsw_sp_port, 1);
 	/* Make sure untagged frames are allowed to ingress */
 	mlxsw_sp_port_pvid_set(mlxsw_sp_port, 1);
 }