diff --git a/drivers/net/vxlan.c b/drivers/net/vxlan.c
index 911066299a836f53107c3c82fa098abf9bee15da..901eef428280ca7c00de3619a28cd214359466ab 100644
--- a/drivers/net/vxlan.c
+++ b/drivers/net/vxlan.c
@@ -188,7 +188,7 @@ static inline struct vxlan_rdst *first_remote_rtnl(struct vxlan_fdb *fdb)
  * and enabled unshareable flags.
  */
 static struct vxlan_sock *vxlan_find_sock(struct net *net, sa_family_t family,
-					  __be16 port, u32 flags)
+					  __be16 port, u32 flags, int ifindex)
 {
 	struct vxlan_sock *vs;
 
@@ -197,7 +197,8 @@ static struct vxlan_sock *vxlan_find_sock(struct net *net, sa_family_t family,
 	hlist_for_each_entry_rcu(vs, vs_head(net, port), hlist) {
 		if (inet_sk(vs->sock->sk)->inet_sport == port &&
 		    vxlan_get_sk_family(vs) == family &&
-		    vs->flags == flags)
+		    vs->flags == flags &&
+		    vs->sock->sk->sk_bound_dev_if == ifindex)
 			return vs;
 	}
 	return NULL;
@@ -237,7 +238,7 @@ static struct vxlan_dev *vxlan_find_vni(struct net *net, int ifindex,
 {
 	struct vxlan_sock *vs;
 
-	vs = vxlan_find_sock(net, family, port, flags);
+	vs = vxlan_find_sock(net, family, port, flags, ifindex);
 	if (!vs)
 		return NULL;
 
@@ -2288,6 +2289,9 @@ static void vxlan_xmit_one(struct sk_buff *skb, struct net_device *dev,
 		struct rtable *rt;
 		__be16 df = 0;
 
+		if (!ifindex)
+			ifindex = sock4->sock->sk->sk_bound_dev_if;
+
 		rt = vxlan_get_route(vxlan, dev, sock4, skb, ifindex, tos,
 				     dst->sin.sin_addr.s_addr,
 				     &local_ip.sin.sin_addr.s_addr,
@@ -2337,6 +2341,9 @@ static void vxlan_xmit_one(struct sk_buff *skb, struct net_device *dev,
 	} else {
 		struct vxlan_sock *sock6 = rcu_dereference(vxlan->vn6_sock);
 
+		if (!ifindex)
+			ifindex = sock6->sock->sk->sk_bound_dev_if;
+
 		ndst = vxlan6_get_route(vxlan, dev, sock6, skb, ifindex, tos,
 					label, &dst->sin6.sin6_addr,
 					&local_ip.sin6.sin6_addr,
@@ -2951,7 +2958,7 @@ static const struct ethtool_ops vxlan_ethtool_ops = {
 };
 
 static struct socket *vxlan_create_sock(struct net *net, bool ipv6,
-					__be16 port, u32 flags)
+					__be16 port, u32 flags, int ifindex)
 {
 	struct socket *sock;
 	struct udp_port_cfg udp_conf;
@@ -2969,6 +2976,7 @@ static struct socket *vxlan_create_sock(struct net *net, bool ipv6,
 	}
 
 	udp_conf.local_udp_port = port;
+	udp_conf.bind_ifindex = ifindex;
 
 	/* Open UDP socket */
 	err = udp_sock_create(net, &udp_conf, &sock);
@@ -2980,7 +2988,8 @@ static struct socket *vxlan_create_sock(struct net *net, bool ipv6,
 
 /* Create new listen socket if needed */
 static struct vxlan_sock *vxlan_socket_create(struct net *net, bool ipv6,
-					      __be16 port, u32 flags)
+					      __be16 port, u32 flags,
+					      int ifindex)
 {
 	struct vxlan_net *vn = net_generic(net, vxlan_net_id);
 	struct vxlan_sock *vs;
@@ -2995,7 +3004,7 @@ static struct vxlan_sock *vxlan_socket_create(struct net *net, bool ipv6,
 	for (h = 0; h < VNI_HASH_SIZE; ++h)
 		INIT_HLIST_HEAD(&vs->vni_list[h]);
 
-	sock = vxlan_create_sock(net, ipv6, port, flags);
+	sock = vxlan_create_sock(net, ipv6, port, flags, ifindex);
 	if (IS_ERR(sock)) {
 		kfree(vs);
 		return ERR_CAST(sock);
@@ -3033,11 +3042,17 @@ static int __vxlan_sock_add(struct vxlan_dev *vxlan, bool ipv6)
 	struct vxlan_net *vn = net_generic(vxlan->net, vxlan_net_id);
 	struct vxlan_sock *vs = NULL;
 	struct vxlan_dev_node *node;
+	int l3mdev_index = 0;
+
+	if (vxlan->cfg.remote_ifindex)
+		l3mdev_index = l3mdev_master_upper_ifindex_by_index(
+			vxlan->net, vxlan->cfg.remote_ifindex);
 
 	if (!vxlan->cfg.no_share) {
 		spin_lock(&vn->sock_lock);
 		vs = vxlan_find_sock(vxlan->net, ipv6 ? AF_INET6 : AF_INET,
-				     vxlan->cfg.dst_port, vxlan->cfg.flags);
+				     vxlan->cfg.dst_port, vxlan->cfg.flags,
+				     l3mdev_index);
 		if (vs && !refcount_inc_not_zero(&vs->refcnt)) {
 			spin_unlock(&vn->sock_lock);
 			return -EBUSY;
@@ -3046,7 +3061,8 @@ static int __vxlan_sock_add(struct vxlan_dev *vxlan, bool ipv6)
 	}
 	if (!vs)
 		vs = vxlan_socket_create(vxlan->net, ipv6,
-					 vxlan->cfg.dst_port, vxlan->cfg.flags);
+					 vxlan->cfg.dst_port, vxlan->cfg.flags,
+					 l3mdev_index);
 	if (IS_ERR(vs))
 		return PTR_ERR(vs);
 #if IS_ENABLED(CONFIG_IPV6)
diff --git a/include/net/l3mdev.h b/include/net/l3mdev.h
index 3832099289c5aa607ed89bc6eb711a8e412980f0..78fa0ac4613c3946590677f53d946d37d97f7bde 100644
--- a/include/net/l3mdev.h
+++ b/include/net/l3mdev.h
@@ -101,6 +101,17 @@ struct net_device *l3mdev_master_dev_rcu(const struct net_device *_dev)
 	return master;
 }
 
+int l3mdev_master_upper_ifindex_by_index_rcu(struct net *net, int ifindex);
+static inline
+int l3mdev_master_upper_ifindex_by_index(struct net *net, int ifindex)
+{
+	rcu_read_lock();
+	ifindex = l3mdev_master_upper_ifindex_by_index_rcu(net, ifindex);
+	rcu_read_unlock();
+
+	return ifindex;
+}
+
 u32 l3mdev_fib_table_rcu(const struct net_device *dev);
 u32 l3mdev_fib_table_by_index(struct net *net, int ifindex);
 static inline u32 l3mdev_fib_table(const struct net_device *dev)
@@ -207,6 +218,17 @@ static inline int l3mdev_master_ifindex_by_index(struct net *net, int ifindex)
 	return 0;
 }
 
+static inline
+int l3mdev_master_upper_ifindex_by_index_rcu(struct net *net, int ifindex)
+{
+	return 0;
+}
+static inline
+int l3mdev_master_upper_ifindex_by_index(struct net *net, int ifindex)
+{
+	return 0;
+}
+
 static inline
 struct net_device *l3mdev_master_dev_rcu(const struct net_device *dev)
 {
diff --git a/include/net/udp_tunnel.h b/include/net/udp_tunnel.h
index dc8d804af3b4d283f88d0d68b7a8c5410cdaac16..b8137953fea31a377b3f5cc4aae99db932ad1ed4 100644
--- a/include/net/udp_tunnel.h
+++ b/include/net/udp_tunnel.h
@@ -30,6 +30,7 @@ struct udp_port_cfg {
 
 	__be16			local_udp_port;
 	__be16			peer_udp_port;
+	int			bind_ifindex;
 	unsigned int		use_udp_checksums:1,
 				use_udp6_tx_checksums:1,
 				use_udp6_rx_checksums:1,
diff --git a/net/ipv4/udp_tunnel.c b/net/ipv4/udp_tunnel.c
index d0c412fc56adcdff4c0df10d5e1546e6ba1368eb..be8b5b2157d8a50b7502d9f229e0e78f59b5bb5c 100644
--- a/net/ipv4/udp_tunnel.c
+++ b/net/ipv4/udp_tunnel.c
@@ -20,6 +20,23 @@ int udp_sock_create4(struct net *net, struct udp_port_cfg *cfg,
 	if (err < 0)
 		goto error;
 
+	if (cfg->bind_ifindex) {
+		struct net_device *dev;
+
+		dev = dev_get_by_index(net, cfg->bind_ifindex);
+		if (!dev) {
+			err = -ENODEV;
+			goto error;
+		}
+
+		err = kernel_setsockopt(sock, SOL_SOCKET, SO_BINDTODEVICE,
+					dev->name, strlen(dev->name) + 1);
+		dev_put(dev);
+
+		if (err < 0)
+			goto error;
+	}
+
 	udp_addr.sin_family = AF_INET;
 	udp_addr.sin_addr = cfg->local_ip;
 	udp_addr.sin_port = cfg->local_udp_port;
diff --git a/net/ipv6/ip6_udp_tunnel.c b/net/ipv6/ip6_udp_tunnel.c
index b283f293ee4ae7537da0bde51b5a4695a2e6f249..3965d5396b0a381a9d5883a97c467a9b1c3d408c 100644
--- a/net/ipv6/ip6_udp_tunnel.c
+++ b/net/ipv6/ip6_udp_tunnel.c
@@ -31,6 +31,22 @@ int udp_sock_create6(struct net *net, struct udp_port_cfg *cfg,
 		if (err < 0)
 			goto error;
 	}
+	if (cfg->bind_ifindex) {
+		struct net_device *dev;
+
+		dev = dev_get_by_index(net, cfg->bind_ifindex);
+		if (!dev) {
+			err = -ENODEV;
+			goto error;
+		}
+
+		err = kernel_setsockopt(sock, SOL_SOCKET, SO_BINDTODEVICE,
+					dev->name, strlen(dev->name) + 1);
+		dev_put(dev);
+
+		if (err < 0)
+			goto error;
+	}
 
 	udp6_addr.sin6_family = AF_INET6;
 	memcpy(&udp6_addr.sin6_addr, &cfg->local_ip6,
diff --git a/net/l3mdev/l3mdev.c b/net/l3mdev/l3mdev.c
index 8da86ceca33ddf4c4e5d61850ce54f22a29e6fb9..309dee76724e505c804da9195f7f9d3c4cdfbdee 100644
--- a/net/l3mdev/l3mdev.c
+++ b/net/l3mdev/l3mdev.c
@@ -46,6 +46,24 @@ int l3mdev_master_ifindex_rcu(const struct net_device *dev)
 }
 EXPORT_SYMBOL_GPL(l3mdev_master_ifindex_rcu);
 
+/**
+ *	l3mdev_master_upper_ifindex_by_index - get index of upper l3 master
+ *					       device
+ *	@net: network namespace for device index lookup
+ *	@ifindex: targeted interface
+ */
+int l3mdev_master_upper_ifindex_by_index_rcu(struct net *net, int ifindex)
+{
+	struct net_device *dev;
+
+	dev = dev_get_by_index_rcu(net, ifindex);
+	while (dev && !netif_is_l3_master(dev))
+		dev = netdev_master_upper_dev_get(dev);
+
+	return dev ? dev->ifindex : 0;
+}
+EXPORT_SYMBOL_GPL(l3mdev_master_upper_ifindex_by_index_rcu);
+
 /**
  *	l3mdev_fib_table - get FIB table id associated with an L3
  *                             master interface
diff --git a/tools/testing/selftests/net/Makefile b/tools/testing/selftests/net/Makefile
index 7aebbcaa91bf4a2dae7eb058b8d4696f1b2dfb54..ee2e27b1cd0d36b291d8db627c79f9c7678e9365 100644
--- a/tools/testing/selftests/net/Makefile
+++ b/tools/testing/selftests/net/Makefile
@@ -7,7 +7,7 @@ CFLAGS += -I../../../../usr/include/
 TEST_PROGS := run_netsocktests run_afpackettests test_bpf.sh netdevice.sh rtnetlink.sh
 TEST_PROGS += fib_tests.sh fib-onlink-tests.sh pmtu.sh udpgso.sh ip_defrag.sh
 TEST_PROGS += udpgso_bench.sh fib_rule_tests.sh msg_zerocopy.sh psock_snd.sh
-TEST_PROGS += udpgro_bench.sh udpgro.sh
+TEST_PROGS += udpgro_bench.sh udpgro.sh test_vxlan_under_vrf.sh
 TEST_PROGS_EXTENDED := in_netns.sh
 TEST_GEN_FILES =  socket
 TEST_GEN_FILES += psock_fanout psock_tpacket msg_zerocopy
diff --git a/tools/testing/selftests/net/test_vxlan_under_vrf.sh b/tools/testing/selftests/net/test_vxlan_under_vrf.sh
new file mode 100755
index 0000000000000000000000000000000000000000..09f9ed92cbe4c8b6e1837698aa273b2f7848145b
--- /dev/null
+++ b/tools/testing/selftests/net/test_vxlan_under_vrf.sh
@@ -0,0 +1,129 @@
+#!/bin/bash
+# SPDX-License-Identifier: GPL-2.0
+
+# This test is for checking VXLAN underlay in a non-default VRF.
+#
+# It simulates two hypervisors running a VM each using four network namespaces:
+# two for the HVs, two for the VMs.
+# A small VXLAN tunnel is made between the two hypervisors to have the two vms
+# in the same virtual L2:
+#
+# +-------------------+                                    +-------------------+
+# |                   |                                    |                   |
+# |    vm-1 netns     |                                    |    vm-2 netns     |
+# |                   |                                    |                   |
+# |  +-------------+  |                                    |  +-------------+  |
+# |  |   veth-hv   |  |                                    |  |   veth-hv   |  |
+# |  | 10.0.0.1/24 |  |                                    |  | 10.0.0.2/24 |  |
+# |  +-------------+  |                                    |  +-------------+  |
+# |        .          |                                    |         .         |
+# +-------------------+                                    +-------------------+
+#          .                                                         .
+#          .                                                         .
+#          .                                                         .
+# +-----------------------------------+   +------------------------------------+
+# |        .                          |   |                          .         |
+# |  +----------+                     |   |                     +----------+   |
+# |  | veth-tap |                     |   |                     | veth-tap |   |
+# |  +----+-----+                     |   |                     +----+-----+   |
+# |       |                           |   |                          |         |
+# |    +--+--+      +--------------+  |   |  +--------------+     +--+--+      |
+# |    | br0 |      | vrf-underlay |  |   |  | vrf-underlay |     | br0 |      |
+# |    +--+--+      +-------+------+  |   |  +------+-------+     +--+--+      |
+# |       |                 |         |   |         |                |         |
+# |   +---+----+    +-------+-------+ |   | +-------+-------+    +---+----+    |
+# |   | vxlan0 |....|     veth0     |.|...|.|     veth0     |....| vxlan0 |    |
+# |   +--------+    | 172.16.0.1/24 | |   | | 172.16.0.2/24 |    +--------+    |
+# |                 +---------------+ |   | +---------------+                  |
+# |                                   |   |                                    |
+# |             hv-1 netns            |   |           hv-2 netns               |
+# |                                   |   |                                    |
+# +-----------------------------------+   +------------------------------------+
+#
+# This tests both the connectivity between vm-1 and vm-2, and that the underlay
+# can be moved in and out of the vrf by unsetting and setting veth0's master.
+
+set -e
+
+cleanup() {
+    ip link del veth-hv-1 2>/dev/null || true
+    ip link del veth-tap 2>/dev/null || true
+
+    for ns in hv-1 hv-2 vm-1 vm-2; do
+        ip netns del $ns || true
+    done
+}
+
+# Clean start
+cleanup &> /dev/null
+
+[[ $1 == "clean" ]] && exit 0
+
+trap cleanup EXIT
+
+# Setup "Hypervisors" simulated with netns
+ip link add veth-hv-1 type veth peer name veth-hv-2
+setup-hv-networking() {
+    hv=$1
+
+    ip netns add hv-$hv
+    ip link set veth-hv-$hv netns hv-$hv
+    ip -netns hv-$hv link set veth-hv-$hv name veth0
+
+    ip -netns hv-$hv link add vrf-underlay type vrf table 1
+    ip -netns hv-$hv link set vrf-underlay up
+    ip -netns hv-$hv addr add 172.16.0.$hv/24 dev veth0
+    ip -netns hv-$hv link set veth0 up
+
+    ip -netns hv-$hv link add br0 type bridge
+    ip -netns hv-$hv link set br0 up
+
+    ip -netns hv-$hv link add vxlan0 type vxlan id 10 local 172.16.0.$hv dev veth0 dstport 4789
+    ip -netns hv-$hv link set vxlan0 master br0
+    ip -netns hv-$hv link set vxlan0 up
+}
+setup-hv-networking 1
+setup-hv-networking 2
+
+# Check connectivity between HVs by pinging hv-2 from hv-1
+echo -n "Checking HV connectivity                                           "
+ip netns exec hv-1 ping -c 1 -W 1 172.16.0.2 &> /dev/null || (echo "[FAIL]"; false)
+echo "[ OK ]"
+
+# Setups a "VM" simulated by a netns an a veth pair
+setup-vm() {
+    id=$1
+
+    ip netns add vm-$id
+    ip link add veth-tap type veth peer name veth-hv
+
+    ip link set veth-tap netns hv-$id
+    ip -netns hv-$id link set veth-tap master br0
+    ip -netns hv-$id link set veth-tap up
+
+    ip link set veth-hv netns vm-$id
+    ip -netns vm-$id addr add 10.0.0.$id/24 dev veth-hv
+    ip -netns vm-$id link set veth-hv up
+}
+setup-vm 1
+setup-vm 2
+
+# Setup VTEP routes to make ARP work
+bridge -netns hv-1 fdb add 00:00:00:00:00:00 dev vxlan0 dst 172.16.0.2 self permanent
+bridge -netns hv-2 fdb add 00:00:00:00:00:00 dev vxlan0 dst 172.16.0.1 self permanent
+
+echo -n "Check VM connectivity through VXLAN (underlay in the default VRF)  "
+ip netns exec vm-1 ping -c 1 -W 1 10.0.0.2 &> /dev/null || (echo "[FAIL]"; false)
+echo "[ OK ]"
+
+# Move the underlay to a non-default VRF
+ip -netns hv-1 link set veth0 vrf vrf-underlay
+ip -netns hv-1 link set veth0 down
+ip -netns hv-1 link set veth0 up
+ip -netns hv-2 link set veth0 vrf vrf-underlay
+ip -netns hv-2 link set veth0 down
+ip -netns hv-2 link set veth0 up
+
+echo -n "Check VM connectivity through VXLAN (underlay in a VRF)            "
+ip netns exec vm-1 ping -c 1 -W 1 10.0.0.2 &> /dev/null || (echo "[FAIL]"; false)
+echo "[ OK ]"