diff --git a/drivers/net/ethernet/mellanox/mlxsw/spectrum.c b/drivers/net/ethernet/mellanox/mlxsw/spectrum.c
index 93378d5079626a7541ec3b9889e907ee9b8c36c8..9cc063af6f7b3919c0c162fff75f45b67e1dc6d7 100644
--- a/drivers/net/ethernet/mellanox/mlxsw/spectrum.c
+++ b/drivers/net/ethernet/mellanox/mlxsw/spectrum.c
@@ -5174,7 +5174,7 @@ static int mlxsw_sp_netdevice_vxlan_event(struct mlxsw_sp *mlxsw_sp,
 			return mlxsw_sp_bridge_vxlan_join(mlxsw_sp, upper_dev,
 							  dev, extack);
 		} else {
-			mlxsw_sp_bridge_vxlan_leave(mlxsw_sp, upper_dev, dev);
+			mlxsw_sp_bridge_vxlan_leave(mlxsw_sp, dev);
 		}
 		break;
 	case NETDEV_PRE_UP:
@@ -5195,7 +5195,7 @@ static int mlxsw_sp_netdevice_vxlan_event(struct mlxsw_sp *mlxsw_sp,
 			return 0;
 		if (!mlxsw_sp_lower_get(upper_dev))
 			return 0;
-		mlxsw_sp_bridge_vxlan_leave(mlxsw_sp, upper_dev, dev);
+		mlxsw_sp_bridge_vxlan_leave(mlxsw_sp, dev);
 		break;
 	}
 
diff --git a/drivers/net/ethernet/mellanox/mlxsw/spectrum.h b/drivers/net/ethernet/mellanox/mlxsw/spectrum.h
index 55ed6903879652808b612af758c8ce2a7f5e04f4..696c2360fbb4263ef00ba875a798e05e8fe1bede 100644
--- a/drivers/net/ethernet/mellanox/mlxsw/spectrum.h
+++ b/drivers/net/ethernet/mellanox/mlxsw/spectrum.h
@@ -361,7 +361,6 @@ int mlxsw_sp_bridge_vxlan_join(struct mlxsw_sp *mlxsw_sp,
 			       const struct net_device *vxlan_dev,
 			       struct netlink_ext_ack *extack);
 void mlxsw_sp_bridge_vxlan_leave(struct mlxsw_sp *mlxsw_sp,
-				 const struct net_device *br_dev,
 				 const struct net_device *vxlan_dev);
 
 /* spectrum.c */
diff --git a/drivers/net/ethernet/mellanox/mlxsw/spectrum_switchdev.c b/drivers/net/ethernet/mellanox/mlxsw/spectrum_switchdev.c
index 108dbb764c7713129dc9044cdd270a5da98472a7..801c1af81916623f820cbff3e5a73374b018df60 100644
--- a/drivers/net/ethernet/mellanox/mlxsw/spectrum_switchdev.c
+++ b/drivers/net/ethernet/mellanox/mlxsw/spectrum_switchdev.c
@@ -87,8 +87,6 @@ struct mlxsw_sp_bridge_ops {
 	int (*vxlan_join)(struct mlxsw_sp_bridge_device *bridge_device,
 			  const struct net_device *vxlan_dev,
 			  struct netlink_ext_ack *extack);
-	void (*vxlan_leave)(struct mlxsw_sp_bridge_device *bridge_device,
-			    const struct net_device *vxlan_dev);
 	struct mlxsw_sp_fid *
 		(*fid_get)(struct mlxsw_sp_bridge_device *bridge_device,
 			   u16 vid);
@@ -2012,12 +2010,6 @@ mlxsw_sp_bridge_8021q_vxlan_join(struct mlxsw_sp_bridge_device *bridge_device,
 	return -EINVAL;
 }
 
-static void
-mlxsw_sp_bridge_8021q_vxlan_leave(struct mlxsw_sp_bridge_device *bridge_device,
-				  const struct net_device *vxlan_dev)
-{
-}
-
 static struct mlxsw_sp_fid *
 mlxsw_sp_bridge_8021q_fid_get(struct mlxsw_sp_bridge_device *bridge_device,
 			      u16 vid)
@@ -2047,7 +2039,6 @@ static const struct mlxsw_sp_bridge_ops mlxsw_sp_bridge_8021q_ops = {
 	.port_join	= mlxsw_sp_bridge_8021q_port_join,
 	.port_leave	= mlxsw_sp_bridge_8021q_port_leave,
 	.vxlan_join	= mlxsw_sp_bridge_8021q_vxlan_join,
-	.vxlan_leave	= mlxsw_sp_bridge_8021q_vxlan_leave,
 	.fid_get	= mlxsw_sp_bridge_8021q_fid_get,
 	.fid_lookup	= mlxsw_sp_bridge_8021q_fid_lookup,
 	.fid_vid	= mlxsw_sp_bridge_8021q_fid_vid,
@@ -2152,26 +2143,6 @@ mlxsw_sp_bridge_8021d_vxlan_join(struct mlxsw_sp_bridge_device *bridge_device,
 	return err;
 }
 
-static void
-mlxsw_sp_bridge_8021d_vxlan_leave(struct mlxsw_sp_bridge_device *bridge_device,
-				  const struct net_device *vxlan_dev)
-{
-	struct mlxsw_sp *mlxsw_sp = mlxsw_sp_lower_get(bridge_device->dev);
-	struct mlxsw_sp_fid *fid;
-
-	fid = mlxsw_sp_fid_8021d_lookup(mlxsw_sp, bridge_device->dev->ifindex);
-	if (WARN_ON(!fid))
-		return;
-
-	/* If the VxLAN device is down, then the FID does not have a VNI */
-	if (!mlxsw_sp_fid_vni_is_set(fid))
-		goto out;
-
-	mlxsw_sp_nve_fid_disable(mlxsw_sp, fid);
-out:
-	mlxsw_sp_fid_put(fid);
-}
-
 static struct mlxsw_sp_fid *
 mlxsw_sp_bridge_8021d_fid_get(struct mlxsw_sp_bridge_device *bridge_device,
 			      u16 vid)
@@ -2230,7 +2201,6 @@ static const struct mlxsw_sp_bridge_ops mlxsw_sp_bridge_8021d_ops = {
 	.port_join	= mlxsw_sp_bridge_8021d_port_join,
 	.port_leave	= mlxsw_sp_bridge_8021d_port_leave,
 	.vxlan_join	= mlxsw_sp_bridge_8021d_vxlan_join,
-	.vxlan_leave	= mlxsw_sp_bridge_8021d_vxlan_leave,
 	.fid_get	= mlxsw_sp_bridge_8021d_fid_get,
 	.fid_lookup	= mlxsw_sp_bridge_8021d_fid_lookup,
 	.fid_vid	= mlxsw_sp_bridge_8021d_fid_vid,
@@ -2298,16 +2268,18 @@ int mlxsw_sp_bridge_vxlan_join(struct mlxsw_sp *mlxsw_sp,
 }
 
 void mlxsw_sp_bridge_vxlan_leave(struct mlxsw_sp *mlxsw_sp,
-				 const struct net_device *br_dev,
 				 const struct net_device *vxlan_dev)
 {
-	struct mlxsw_sp_bridge_device *bridge_device;
+	struct vxlan_dev *vxlan = netdev_priv(vxlan_dev);
+	struct mlxsw_sp_fid *fid;
 
-	bridge_device = mlxsw_sp_bridge_device_find(mlxsw_sp->bridge, br_dev);
-	if (WARN_ON(!bridge_device))
+	/* If the VxLAN device is down, then the FID does not have a VNI */
+	fid = mlxsw_sp_fid_lookup_by_vni(mlxsw_sp, vxlan->cfg.vni);
+	if (!fid)
 		return;
 
-	bridge_device->ops->vxlan_leave(bridge_device, vxlan_dev);
+	mlxsw_sp_nve_fid_disable(mlxsw_sp, fid);
+	mlxsw_sp_fid_put(fid);
 }
 
 static void