diff --git a/include/linux/netdevice.h b/include/linux/netdevice.h
index 7ed3a3aa6604b7b8511ea709ea2c2312a1d38bb2..20e99efb1ca6b15f2e4f2e46f5e3c8b2e606b046 100644
--- a/include/linux/netdevice.h
+++ b/include/linux/netdevice.h
@@ -3180,12 +3180,7 @@ void netdev_change_features(struct net_device *dev);
 void netif_stacked_transfer_operstate(const struct net_device *rootdev,
 					struct net_device *dev);
 
-netdev_features_t netif_skb_dev_features(struct sk_buff *skb,
-					 const struct net_device *dev);
-static inline netdev_features_t netif_skb_features(struct sk_buff *skb)
-{
-	return netif_skb_dev_features(skb, skb->dev);
-}
+netdev_features_t netif_skb_features(struct sk_buff *skb);
 
 static inline bool net_gso_ok(netdev_features_t features, int gso_type)
 {
diff --git a/net/core/dev.c b/net/core/dev.c
index d2c8a06b3a9883b618c55a8cddf0929a4bcd049f..c619b864133789f5205c02d002f50f5d410ab456 100644
--- a/net/core/dev.c
+++ b/net/core/dev.c
@@ -2418,7 +2418,7 @@ EXPORT_SYMBOL(netdev_rx_csum_fault);
  * 2. No high memory really exists on this machine.
  */
 
-static int illegal_highdma(const struct net_device *dev, struct sk_buff *skb)
+static int illegal_highdma(struct net_device *dev, struct sk_buff *skb)
 {
 #ifdef CONFIG_HIGHMEM
 	int i;
@@ -2493,38 +2493,36 @@ static int dev_gso_segment(struct sk_buff *skb, netdev_features_t features)
 }
 
 static netdev_features_t harmonize_features(struct sk_buff *skb,
-					    const struct net_device *dev,
-					    netdev_features_t features)
+	netdev_features_t features)
 {
 	int tmp;
 
 	if (skb->ip_summed != CHECKSUM_NONE &&
 	    !can_checksum_protocol(features, skb_network_protocol(skb, &tmp))) {
 		features &= ~NETIF_F_ALL_CSUM;
-	} else if (illegal_highdma(dev, skb)) {
+	} else if (illegal_highdma(skb->dev, skb)) {
 		features &= ~NETIF_F_SG;
 	}
 
 	return features;
 }
 
-netdev_features_t netif_skb_dev_features(struct sk_buff *skb,
-					 const struct net_device *dev)
+netdev_features_t netif_skb_features(struct sk_buff *skb)
 {
 	__be16 protocol = skb->protocol;
-	netdev_features_t features = dev->features;
+	netdev_features_t features = skb->dev->features;
 
-	if (skb_shinfo(skb)->gso_segs > dev->gso_max_segs)
+	if (skb_shinfo(skb)->gso_segs > skb->dev->gso_max_segs)
 		features &= ~NETIF_F_GSO_MASK;
 
 	if (protocol == htons(ETH_P_8021Q) || protocol == htons(ETH_P_8021AD)) {
 		struct vlan_ethhdr *veh = (struct vlan_ethhdr *)skb->data;
 		protocol = veh->h_vlan_encapsulated_proto;
 	} else if (!vlan_tx_tag_present(skb)) {
-		return harmonize_features(skb, dev, features);
+		return harmonize_features(skb, features);
 	}
 
-	features &= (dev->vlan_features | NETIF_F_HW_VLAN_CTAG_TX |
+	features &= (skb->dev->vlan_features | NETIF_F_HW_VLAN_CTAG_TX |
 					       NETIF_F_HW_VLAN_STAG_TX);
 
 	if (protocol == htons(ETH_P_8021Q) || protocol == htons(ETH_P_8021AD))
@@ -2532,9 +2530,9 @@ netdev_features_t netif_skb_dev_features(struct sk_buff *skb,
 				NETIF_F_GEN_CSUM | NETIF_F_HW_VLAN_CTAG_TX |
 				NETIF_F_HW_VLAN_STAG_TX;
 
-	return harmonize_features(skb, dev, features);
+	return harmonize_features(skb, features);
 }
-EXPORT_SYMBOL(netif_skb_dev_features);
+EXPORT_SYMBOL(netif_skb_features);
 
 int dev_hard_start_xmit(struct sk_buff *skb, struct net_device *dev,
 			struct netdev_queue *txq)
diff --git a/net/ipv4/ip_forward.c b/net/ipv4/ip_forward.c
index c29ae8371e44b7215a271c4e95de2282ac787b27..6f111e48e11c15a19ffd60d345b1399fd3c66953 100644
--- a/net/ipv4/ip_forward.c
+++ b/net/ipv4/ip_forward.c
@@ -56,53 +56,6 @@ static bool ip_exceeds_mtu(const struct sk_buff *skb, unsigned int mtu)
 	return true;
 }
 
-static bool ip_gso_exceeds_dst_mtu(const struct sk_buff *skb)
-{
-	unsigned int mtu;
-
-	if (skb->local_df || !skb_is_gso(skb))
-		return false;
-
-	mtu = ip_dst_mtu_maybe_forward(skb_dst(skb), true);
-
-	/* if seglen > mtu, do software segmentation for IP fragmentation on
-	 * output.  DF bit cannot be set since ip_forward would have sent
-	 * icmp error.
-	 */
-	return skb_gso_network_seglen(skb) > mtu;
-}
-
-/* called if GSO skb needs to be fragmented on forward */
-static int ip_forward_finish_gso(struct sk_buff *skb)
-{
-	struct dst_entry *dst = skb_dst(skb);
-	netdev_features_t features;
-	struct sk_buff *segs;
-	int ret = 0;
-
-	features = netif_skb_dev_features(skb, dst->dev);
-	segs = skb_gso_segment(skb, features & ~NETIF_F_GSO_MASK);
-	if (IS_ERR(segs)) {
-		kfree_skb(skb);
-		return -ENOMEM;
-	}
-
-	consume_skb(skb);
-
-	do {
-		struct sk_buff *nskb = segs->next;
-		int err;
-
-		segs->next = NULL;
-		err = dst_output(segs);
-
-		if (err && ret == 0)
-			ret = err;
-		segs = nskb;
-	} while (segs);
-
-	return ret;
-}
 
 static int ip_forward_finish(struct sk_buff *skb)
 {
@@ -114,9 +67,6 @@ static int ip_forward_finish(struct sk_buff *skb)
 	if (unlikely(opt->optlen))
 		ip_forward_options(skb);
 
-	if (ip_gso_exceeds_dst_mtu(skb))
-		return ip_forward_finish_gso(skb);
-
 	return dst_output(skb);
 }
 
diff --git a/net/ipv4/ip_output.c b/net/ipv4/ip_output.c
index 1cbeba5edff90fa1ac891d4dd23cfb65464878a4..a52f50187b5495a1157c2ef8e0785da730414022 100644
--- a/net/ipv4/ip_output.c
+++ b/net/ipv4/ip_output.c
@@ -211,6 +211,48 @@ static inline int ip_finish_output2(struct sk_buff *skb)
 	return -EINVAL;
 }
 
+static int ip_finish_output_gso(struct sk_buff *skb)
+{
+	netdev_features_t features;
+	struct sk_buff *segs;
+	int ret = 0;
+
+	/* common case: locally created skb or seglen is <= mtu */
+	if (((IPCB(skb)->flags & IPSKB_FORWARDED) == 0) ||
+	      skb_gso_network_seglen(skb) <= ip_skb_dst_mtu(skb))
+		return ip_finish_output2(skb);
+
+	/* Slowpath -  GSO segment length is exceeding the dst MTU.
+	 *
+	 * This can happen in two cases:
+	 * 1) TCP GRO packet, DF bit not set
+	 * 2) skb arrived via virtio-net, we thus get TSO/GSO skbs directly
+	 * from host network stack.
+	 */
+	features = netif_skb_features(skb);
+	segs = skb_gso_segment(skb, features & ~NETIF_F_GSO_MASK);
+	if (IS_ERR(segs)) {
+		kfree_skb(skb);
+		return -ENOMEM;
+	}
+
+	consume_skb(skb);
+
+	do {
+		struct sk_buff *nskb = segs->next;
+		int err;
+
+		segs->next = NULL;
+		err = ip_fragment(segs, ip_finish_output2);
+
+		if (err && ret == 0)
+			ret = err;
+		segs = nskb;
+	} while (segs);
+
+	return ret;
+}
+
 static int ip_finish_output(struct sk_buff *skb)
 {
 #if defined(CONFIG_NETFILTER) && defined(CONFIG_XFRM)
@@ -220,10 +262,13 @@ static int ip_finish_output(struct sk_buff *skb)
 		return dst_output(skb);
 	}
 #endif
-	if (skb->len > ip_skb_dst_mtu(skb) && !skb_is_gso(skb))
+	if (skb_is_gso(skb))
+		return ip_finish_output_gso(skb);
+
+	if (skb->len > ip_skb_dst_mtu(skb))
 		return ip_fragment(skb, ip_finish_output2);
-	else
-		return ip_finish_output2(skb);
+
+	return ip_finish_output2(skb);
 }
 
 int ip_mc_output(struct sock *sk, struct sk_buff *skb)