提交 a8c9486b 编写于 作者: E Eric Dumazet 提交者: David S. Miller

ipmr: RCU protection for mfc_cache_array

Use RCU & RTNL protection for mfc_cache_array[]

ipmr_cache_find() is called under rcu_read_lock();
Signed-off-by: NEric Dumazet <eric.dumazet@gmail.com>
Signed-off-by: NDavid S. Miller <davem@davemloft.net>
上级 4c968709
...@@ -213,6 +213,7 @@ struct mfc_cache { ...@@ -213,6 +213,7 @@ struct mfc_cache {
unsigned char ttls[MAXVIFS]; /* TTL thresholds */ unsigned char ttls[MAXVIFS]; /* TTL thresholds */
} res; } res;
} mfc_un; } mfc_un;
struct rcu_head rcu;
}; };
#define MFC_STATIC 1 #define MFC_STATIC 1
......
...@@ -577,11 +577,18 @@ static int vif_delete(struct mr_table *mrt, int vifi, int notify, ...@@ -577,11 +577,18 @@ static int vif_delete(struct mr_table *mrt, int vifi, int notify,
return 0; return 0;
} }
static inline void ipmr_cache_free(struct mfc_cache *c) static void ipmr_cache_free_rcu(struct rcu_head *head)
{ {
struct mfc_cache *c = container_of(head, struct mfc_cache, rcu);
kmem_cache_free(mrt_cachep, c); kmem_cache_free(mrt_cachep, c);
} }
static inline void ipmr_cache_free(struct mfc_cache *c)
{
call_rcu(&c->rcu, ipmr_cache_free_rcu);
}
/* Destroy an unresolved cache entry, killing queued skbs /* Destroy an unresolved cache entry, killing queued skbs
and reporting error to netlink readers. and reporting error to netlink readers.
*/ */
...@@ -781,6 +788,7 @@ static int vif_add(struct net *net, struct mr_table *mrt, ...@@ -781,6 +788,7 @@ static int vif_add(struct net *net, struct mr_table *mrt,
return 0; return 0;
} }
/* called with rcu_read_lock() */
static struct mfc_cache *ipmr_cache_find(struct mr_table *mrt, static struct mfc_cache *ipmr_cache_find(struct mr_table *mrt,
__be32 origin, __be32 origin,
__be32 mcastgrp) __be32 mcastgrp)
...@@ -788,7 +796,7 @@ static struct mfc_cache *ipmr_cache_find(struct mr_table *mrt, ...@@ -788,7 +796,7 @@ static struct mfc_cache *ipmr_cache_find(struct mr_table *mrt,
int line = MFC_HASH(mcastgrp, origin); int line = MFC_HASH(mcastgrp, origin);
struct mfc_cache *c; struct mfc_cache *c;
list_for_each_entry(c, &mrt->mfc_cache_array[line], list) { list_for_each_entry_rcu(c, &mrt->mfc_cache_array[line], list) {
if (c->mfc_origin == origin && c->mfc_mcastgrp == mcastgrp) if (c->mfc_origin == origin && c->mfc_mcastgrp == mcastgrp)
return c; return c;
} }
...@@ -801,19 +809,20 @@ static struct mfc_cache *ipmr_cache_find(struct mr_table *mrt, ...@@ -801,19 +809,20 @@ static struct mfc_cache *ipmr_cache_find(struct mr_table *mrt,
static struct mfc_cache *ipmr_cache_alloc(void) static struct mfc_cache *ipmr_cache_alloc(void)
{ {
struct mfc_cache *c = kmem_cache_zalloc(mrt_cachep, GFP_KERNEL); struct mfc_cache *c = kmem_cache_zalloc(mrt_cachep, GFP_KERNEL);
if (c == NULL)
return NULL; if (c)
c->mfc_un.res.minvif = MAXVIFS; c->mfc_un.res.minvif = MAXVIFS;
return c; return c;
} }
static struct mfc_cache *ipmr_cache_alloc_unres(void) static struct mfc_cache *ipmr_cache_alloc_unres(void)
{ {
struct mfc_cache *c = kmem_cache_zalloc(mrt_cachep, GFP_ATOMIC); struct mfc_cache *c = kmem_cache_zalloc(mrt_cachep, GFP_ATOMIC);
if (c == NULL)
return NULL; if (c) {
skb_queue_head_init(&c->mfc_un.unres.unresolved); skb_queue_head_init(&c->mfc_un.unres.unresolved);
c->mfc_un.unres.expires = jiffies + 10*HZ; c->mfc_un.unres.expires = jiffies + 10*HZ;
}
return c; return c;
} }
...@@ -1040,9 +1049,7 @@ static int ipmr_mfc_delete(struct mr_table *mrt, struct mfcctl *mfc) ...@@ -1040,9 +1049,7 @@ static int ipmr_mfc_delete(struct mr_table *mrt, struct mfcctl *mfc)
list_for_each_entry_safe(c, next, &mrt->mfc_cache_array[line], list) { list_for_each_entry_safe(c, next, &mrt->mfc_cache_array[line], list) {
if (c->mfc_origin == mfc->mfcc_origin.s_addr && if (c->mfc_origin == mfc->mfcc_origin.s_addr &&
c->mfc_mcastgrp == mfc->mfcc_mcastgrp.s_addr) { c->mfc_mcastgrp == mfc->mfcc_mcastgrp.s_addr) {
write_lock_bh(&mrt_lock); list_del_rcu(&c->list);
list_del(&c->list);
write_unlock_bh(&mrt_lock);
ipmr_cache_free(c); ipmr_cache_free(c);
return 0; return 0;
...@@ -1095,9 +1102,7 @@ static int ipmr_mfc_add(struct net *net, struct mr_table *mrt, ...@@ -1095,9 +1102,7 @@ static int ipmr_mfc_add(struct net *net, struct mr_table *mrt,
if (!mrtsock) if (!mrtsock)
c->mfc_flags |= MFC_STATIC; c->mfc_flags |= MFC_STATIC;
write_lock_bh(&mrt_lock); list_add_rcu(&c->list, &mrt->mfc_cache_array[line]);
list_add(&c->list, &mrt->mfc_cache_array[line]);
write_unlock_bh(&mrt_lock);
/* /*
* Check to see if we resolved a queued list. If so we * Check to see if we resolved a queued list. If so we
...@@ -1149,12 +1154,9 @@ static void mroute_clean_tables(struct mr_table *mrt) ...@@ -1149,12 +1154,9 @@ static void mroute_clean_tables(struct mr_table *mrt)
*/ */
for (i = 0; i < MFC_LINES; i++) { for (i = 0; i < MFC_LINES; i++) {
list_for_each_entry_safe(c, next, &mrt->mfc_cache_array[i], list) { list_for_each_entry_safe(c, next, &mrt->mfc_cache_array[i], list) {
if (c->mfc_flags&MFC_STATIC) if (c->mfc_flags & MFC_STATIC)
continue; continue;
write_lock_bh(&mrt_lock); list_del_rcu(&c->list);
list_del(&c->list);
write_unlock_bh(&mrt_lock);
ipmr_cache_free(c); ipmr_cache_free(c);
} }
} }
...@@ -1422,19 +1424,19 @@ int ipmr_ioctl(struct sock *sk, int cmd, void __user *arg) ...@@ -1422,19 +1424,19 @@ int ipmr_ioctl(struct sock *sk, int cmd, void __user *arg)
if (copy_from_user(&sr, arg, sizeof(sr))) if (copy_from_user(&sr, arg, sizeof(sr)))
return -EFAULT; return -EFAULT;
read_lock(&mrt_lock); rcu_read_lock();
c = ipmr_cache_find(mrt, sr.src.s_addr, sr.grp.s_addr); c = ipmr_cache_find(mrt, sr.src.s_addr, sr.grp.s_addr);
if (c) { if (c) {
sr.pktcnt = c->mfc_un.res.pkt; sr.pktcnt = c->mfc_un.res.pkt;
sr.bytecnt = c->mfc_un.res.bytes; sr.bytecnt = c->mfc_un.res.bytes;
sr.wrong_if = c->mfc_un.res.wrong_if; sr.wrong_if = c->mfc_un.res.wrong_if;
read_unlock(&mrt_lock); rcu_read_unlock();
if (copy_to_user(arg, &sr, sizeof(sr))) if (copy_to_user(arg, &sr, sizeof(sr)))
return -EFAULT; return -EFAULT;
return 0; return 0;
} }
read_unlock(&mrt_lock); rcu_read_unlock();
return -EADDRNOTAVAIL; return -EADDRNOTAVAIL;
default: default:
return -ENOIOCTLCMD; return -ENOIOCTLCMD;
...@@ -1764,7 +1766,7 @@ int ip_mr_input(struct sk_buff *skb) ...@@ -1764,7 +1766,7 @@ int ip_mr_input(struct sk_buff *skb)
} }
} }
read_lock(&mrt_lock); /* already under rcu_read_lock() */
cache = ipmr_cache_find(mrt, ip_hdr(skb)->saddr, ip_hdr(skb)->daddr); cache = ipmr_cache_find(mrt, ip_hdr(skb)->saddr, ip_hdr(skb)->daddr);
/* /*
...@@ -1776,13 +1778,12 @@ int ip_mr_input(struct sk_buff *skb) ...@@ -1776,13 +1778,12 @@ int ip_mr_input(struct sk_buff *skb)
if (local) { if (local) {
struct sk_buff *skb2 = skb_clone(skb, GFP_ATOMIC); struct sk_buff *skb2 = skb_clone(skb, GFP_ATOMIC);
ip_local_deliver(skb); ip_local_deliver(skb);
if (skb2 == NULL) { if (skb2 == NULL)
read_unlock(&mrt_lock);
return -ENOBUFS; return -ENOBUFS;
}
skb = skb2; skb = skb2;
} }
read_lock(&mrt_lock);
vif = ipmr_find_vif(mrt, skb->dev); vif = ipmr_find_vif(mrt, skb->dev);
if (vif >= 0) { if (vif >= 0) {
int err2 = ipmr_cache_unresolved(mrt, vif, skb); int err2 = ipmr_cache_unresolved(mrt, vif, skb);
...@@ -1795,8 +1796,8 @@ int ip_mr_input(struct sk_buff *skb) ...@@ -1795,8 +1796,8 @@ int ip_mr_input(struct sk_buff *skb)
return -ENODEV; return -ENODEV;
} }
read_lock(&mrt_lock);
ip_mr_forward(net, mrt, skb, cache, local); ip_mr_forward(net, mrt, skb, cache, local);
read_unlock(&mrt_lock); read_unlock(&mrt_lock);
if (local) if (local)
...@@ -1963,7 +1964,7 @@ int ipmr_get_route(struct net *net, ...@@ -1963,7 +1964,7 @@ int ipmr_get_route(struct net *net,
if (mrt == NULL) if (mrt == NULL)
return -ENOENT; return -ENOENT;
read_lock(&mrt_lock); rcu_read_lock();
cache = ipmr_cache_find(mrt, rt->rt_src, rt->rt_dst); cache = ipmr_cache_find(mrt, rt->rt_src, rt->rt_dst);
if (cache == NULL) { if (cache == NULL) {
...@@ -1973,18 +1974,21 @@ int ipmr_get_route(struct net *net, ...@@ -1973,18 +1974,21 @@ int ipmr_get_route(struct net *net,
int vif; int vif;
if (nowait) { if (nowait) {
read_unlock(&mrt_lock); rcu_read_unlock();
return -EAGAIN; return -EAGAIN;
} }
dev = skb->dev; dev = skb->dev;
read_lock(&mrt_lock);
if (dev == NULL || (vif = ipmr_find_vif(mrt, dev)) < 0) { if (dev == NULL || (vif = ipmr_find_vif(mrt, dev)) < 0) {
read_unlock(&mrt_lock); read_unlock(&mrt_lock);
rcu_read_unlock();
return -ENODEV; return -ENODEV;
} }
skb2 = skb_clone(skb, GFP_ATOMIC); skb2 = skb_clone(skb, GFP_ATOMIC);
if (!skb2) { if (!skb2) {
read_unlock(&mrt_lock); read_unlock(&mrt_lock);
rcu_read_unlock();
return -ENOMEM; return -ENOMEM;
} }
...@@ -1997,13 +2001,16 @@ int ipmr_get_route(struct net *net, ...@@ -1997,13 +2001,16 @@ int ipmr_get_route(struct net *net,
iph->version = 0; iph->version = 0;
err = ipmr_cache_unresolved(mrt, vif, skb2); err = ipmr_cache_unresolved(mrt, vif, skb2);
read_unlock(&mrt_lock); read_unlock(&mrt_lock);
rcu_read_unlock();
return err; return err;
} }
if (!nowait && (rtm->rtm_flags&RTM_F_NOTIFY)) read_lock(&mrt_lock);
if (!nowait && (rtm->rtm_flags & RTM_F_NOTIFY))
cache->mfc_flags |= MFC_NOTIFY; cache->mfc_flags |= MFC_NOTIFY;
err = __ipmr_fill_mroute(mrt, skb, cache, rtm); err = __ipmr_fill_mroute(mrt, skb, cache, rtm);
read_unlock(&mrt_lock); read_unlock(&mrt_lock);
rcu_read_unlock();
return err; return err;
} }
...@@ -2055,14 +2062,14 @@ static int ipmr_rtm_dumproute(struct sk_buff *skb, struct netlink_callback *cb) ...@@ -2055,14 +2062,14 @@ static int ipmr_rtm_dumproute(struct sk_buff *skb, struct netlink_callback *cb)
s_h = cb->args[1]; s_h = cb->args[1];
s_e = cb->args[2]; s_e = cb->args[2];
read_lock(&mrt_lock); rcu_read_lock();
ipmr_for_each_table(mrt, net) { ipmr_for_each_table(mrt, net) {
if (t < s_t) if (t < s_t)
goto next_table; goto next_table;
if (t > s_t) if (t > s_t)
s_h = 0; s_h = 0;
for (h = s_h; h < MFC_LINES; h++) { for (h = s_h; h < MFC_LINES; h++) {
list_for_each_entry(mfc, &mrt->mfc_cache_array[h], list) { list_for_each_entry_rcu(mfc, &mrt->mfc_cache_array[h], list) {
if (e < s_e) if (e < s_e)
goto next_entry; goto next_entry;
if (ipmr_fill_mroute(mrt, skb, if (ipmr_fill_mroute(mrt, skb,
...@@ -2080,7 +2087,7 @@ static int ipmr_rtm_dumproute(struct sk_buff *skb, struct netlink_callback *cb) ...@@ -2080,7 +2087,7 @@ static int ipmr_rtm_dumproute(struct sk_buff *skb, struct netlink_callback *cb)
t++; t++;
} }
done: done:
read_unlock(&mrt_lock); rcu_read_unlock();
cb->args[2] = e; cb->args[2] = e;
cb->args[1] = h; cb->args[1] = h;
...@@ -2213,14 +2220,14 @@ static struct mfc_cache *ipmr_mfc_seq_idx(struct net *net, ...@@ -2213,14 +2220,14 @@ static struct mfc_cache *ipmr_mfc_seq_idx(struct net *net,
struct mr_table *mrt = it->mrt; struct mr_table *mrt = it->mrt;
struct mfc_cache *mfc; struct mfc_cache *mfc;
read_lock(&mrt_lock); rcu_read_lock();
for (it->ct = 0; it->ct < MFC_LINES; it->ct++) { for (it->ct = 0; it->ct < MFC_LINES; it->ct++) {
it->cache = &mrt->mfc_cache_array[it->ct]; it->cache = &mrt->mfc_cache_array[it->ct];
list_for_each_entry(mfc, it->cache, list) list_for_each_entry_rcu(mfc, it->cache, list)
if (pos-- == 0) if (pos-- == 0)
return mfc; return mfc;
} }
read_unlock(&mrt_lock); rcu_read_unlock();
spin_lock_bh(&mfc_unres_lock); spin_lock_bh(&mfc_unres_lock);
it->cache = &mrt->mfc_unres_queue; it->cache = &mrt->mfc_unres_queue;
...@@ -2279,7 +2286,7 @@ static void *ipmr_mfc_seq_next(struct seq_file *seq, void *v, loff_t *pos) ...@@ -2279,7 +2286,7 @@ static void *ipmr_mfc_seq_next(struct seq_file *seq, void *v, loff_t *pos)
} }
/* exhausted cache_array, show unresolved */ /* exhausted cache_array, show unresolved */
read_unlock(&mrt_lock); rcu_read_unlock();
it->cache = &mrt->mfc_unres_queue; it->cache = &mrt->mfc_unres_queue;
it->ct = 0; it->ct = 0;
...@@ -2302,7 +2309,7 @@ static void ipmr_mfc_seq_stop(struct seq_file *seq, void *v) ...@@ -2302,7 +2309,7 @@ static void ipmr_mfc_seq_stop(struct seq_file *seq, void *v)
if (it->cache == &mrt->mfc_unres_queue) if (it->cache == &mrt->mfc_unres_queue)
spin_unlock_bh(&mfc_unres_lock); spin_unlock_bh(&mfc_unres_lock);
else if (it->cache == &mrt->mfc_cache_array[it->ct]) else if (it->cache == &mrt->mfc_cache_array[it->ct])
read_unlock(&mrt_lock); rcu_read_unlock();
} }
static int ipmr_mfc_seq_show(struct seq_file *seq, void *v) static int ipmr_mfc_seq_show(struct seq_file *seq, void *v)
...@@ -2426,7 +2433,7 @@ int __init ip_mr_init(void) ...@@ -2426,7 +2433,7 @@ int __init ip_mr_init(void)
mrt_cachep = kmem_cache_create("ip_mrt_cache", mrt_cachep = kmem_cache_create("ip_mrt_cache",
sizeof(struct mfc_cache), sizeof(struct mfc_cache),
0, SLAB_HWCACHE_ALIGN|SLAB_PANIC, 0, SLAB_HWCACHE_ALIGN | SLAB_PANIC,
NULL); NULL);
if (!mrt_cachep) if (!mrt_cachep)
return -ENOMEM; return -ENOMEM;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册