diff --git a/drivers/net/vrf.c b/drivers/net/vrf.c index 0ea29345eb2e92cd2fb4be55584d9f0410444ac1..dff08842f26d034dc21a7a73b8a1c9b0cb72236a 100644 --- a/drivers/net/vrf.c +++ b/drivers/net/vrf.c @@ -43,8 +43,8 @@ #define DRV_VERSION "1.0" struct net_vrf { - struct rtable *rth; - struct rt6_info *rt6; + struct rtable __rcu *rth; + struct rt6_info __rcu *rt6; u32 tb_id; }; @@ -273,10 +273,15 @@ static int vrf_output6(struct net *net, struct sock *sk, struct sk_buff *skb) !(IP6CB(skb)->flags & IP6SKB_REROUTED)); } +/* holding rtnl */ static void vrf_rt6_release(struct net_vrf *vrf) { - dst_release(&vrf->rt6->dst); - vrf->rt6 = NULL; + struct rt6_info *rt6 = rtnl_dereference(vrf->rt6); + + rcu_assign_pointer(vrf->rt6, NULL); + + if (rt6) + dst_release(&rt6->dst); } static int vrf_rt6_create(struct net_device *dev) @@ -300,7 +305,8 @@ static int vrf_rt6_create(struct net_device *dev) rt6->rt6i_table = rt6i_table; rt6->dst.output = vrf_output6; - vrf->rt6 = rt6; + rcu_assign_pointer(vrf->rt6, rt6); + rc = 0; out: return rc; @@ -374,29 +380,35 @@ static int vrf_output(struct net *net, struct sock *sk, struct sk_buff *skb) !(IPCB(skb)->flags & IPSKB_REROUTED)); } +/* holding rtnl */ static void vrf_rtable_release(struct net_vrf *vrf) { - struct dst_entry *dst = (struct dst_entry *)vrf->rth; + struct rtable *rth = rtnl_dereference(vrf->rth); + + rcu_assign_pointer(vrf->rth, NULL); - dst_release(dst); - vrf->rth = NULL; + if (rth) + dst_release(&rth->dst); } -static struct rtable *vrf_rtable_create(struct net_device *dev) +static int vrf_rtable_create(struct net_device *dev) { struct net_vrf *vrf = netdev_priv(dev); struct rtable *rth; if (!fib_new_table(dev_net(dev), vrf->tb_id)) - return NULL; + return -ENOMEM; rth = rt_dst_alloc(dev, 0, RTN_UNICAST, 1, 1, 0); - if (rth) { - rth->dst.output = vrf_output; - rth->rt_table_id = vrf->tb_id; - } + if (!rth) + return -ENOMEM; - return rth; + rth->dst.output = vrf_output; + rth->rt_table_id = vrf->tb_id; + + rcu_assign_pointer(vrf->rth, rth); + + return 0; } /**************************** device handling ********************/ @@ -484,8 +496,7 @@ static int vrf_dev_init(struct net_device *dev) goto out_nomem; /* create the default dst which points back to us */ - vrf->rth = vrf_rtable_create(dev); - if (!vrf->rth) + if (vrf_rtable_create(dev) != 0) goto out_stats; if (vrf_rt6_create(dev) != 0) @@ -528,8 +539,13 @@ static struct rtable *vrf_get_rtable(const struct net_device *dev, if (!(fl4->flowi4_flags & FLOWI_FLAG_L3MDEV_SRC)) { struct net_vrf *vrf = netdev_priv(dev); - rth = vrf->rth; - dst_hold(&rth->dst); + rcu_read_lock(); + + rth = rcu_dereference(vrf->rth); + if (likely(rth)) + dst_hold(&rth->dst); + + rcu_read_unlock(); } return rth; @@ -665,16 +681,24 @@ static struct sk_buff *vrf_l3_rcv(struct net_device *vrf_dev, static struct dst_entry *vrf_get_rt6_dst(const struct net_device *dev, const struct flowi6 *fl6) { - struct rt6_info *rt = NULL; + struct dst_entry *dst = NULL; if (!(fl6->flowi6_flags & FLOWI_FLAG_L3MDEV_SRC)) { struct net_vrf *vrf = netdev_priv(dev); + struct rt6_info *rt; + + rcu_read_lock(); + + rt = rcu_dereference(vrf->rt6); + if (likely(rt)) { + dst = &rt->dst; + dst_hold(dst); + } - rt = vrf->rt6; - dst_hold(&rt->dst); + rcu_read_unlock(); } - return (struct dst_entry *)rt; + return dst; } #endif