diff --git a/net/sched/cls_flower.c b/net/sched/cls_flower.c index d36ceb5001f9c98574850f2fa10183e2283a078c..9ed7c9b804a7598222a5bbd01f6cd8be06c33e2e 100644 --- a/net/sched/cls_flower.c +++ b/net/sched/cls_flower.c @@ -14,6 +14,7 @@ #include #include #include +#include #include #include @@ -104,6 +105,11 @@ struct cls_fl_filter { u32 in_hw_count; struct rcu_work rwork; struct net_device *hw_dev; + /* Flower classifier is unlocked, which means that its reference counter + * can be changed concurrently without any kind of external + * synchronization. Use atomic reference counter to be concurrency-safe. + */ + refcount_t refcnt; }; static const struct rhashtable_params mask_ht_params = { @@ -447,6 +453,48 @@ static struct cls_fl_head *fl_head_dereference(struct tcf_proto *tp) return rcu_dereference_raw(tp->root); } +static void __fl_put(struct cls_fl_filter *f) +{ + if (!refcount_dec_and_test(&f->refcnt)) + return; + + if (tcf_exts_get_net(&f->exts)) + tcf_queue_work(&f->rwork, fl_destroy_filter_work); + else + __fl_destroy_filter(f); +} + +static struct cls_fl_filter *__fl_get(struct cls_fl_head *head, u32 handle) +{ + struct cls_fl_filter *f; + + rcu_read_lock(); + f = idr_find(&head->handle_idr, handle); + if (f && !refcount_inc_not_zero(&f->refcnt)) + f = NULL; + rcu_read_unlock(); + + return f; +} + +static struct cls_fl_filter *fl_get_next_filter(struct tcf_proto *tp, + unsigned long *handle) +{ + struct cls_fl_head *head = fl_head_dereference(tp); + struct cls_fl_filter *f; + + rcu_read_lock(); + while ((f = idr_get_next_ul(&head->handle_idr, handle))) { + /* don't return filters that are being deleted */ + if (refcount_inc_not_zero(&f->refcnt)) + break; + ++(*handle); + } + rcu_read_unlock(); + + return f; +} + static bool __fl_delete(struct tcf_proto *tp, struct cls_fl_filter *f, struct netlink_ext_ack *extack) { @@ -460,10 +508,7 @@ static bool __fl_delete(struct tcf_proto *tp, struct cls_fl_filter *f, if (!tc_skip_hw(f->flags)) fl_hw_destroy_filter(tp, f, extack); tcf_unbind_filter(tp, &f->res); - if (async) - tcf_queue_work(&f->rwork, fl_destroy_filter_work); - else - __fl_destroy_filter(f); + __fl_put(f); return last; } @@ -498,11 +543,18 @@ static void fl_destroy(struct tcf_proto *tp, bool rtnl_held, tcf_queue_work(&head->rwork, fl_destroy_sleepable); } +static void fl_put(struct tcf_proto *tp, void *arg) +{ + struct cls_fl_filter *f = arg; + + __fl_put(f); +} + static void *fl_get(struct tcf_proto *tp, u32 handle) { struct cls_fl_head *head = fl_head_dereference(tp); - return idr_find(&head->handle_idr, handle); + return __fl_get(head, handle); } static const struct nla_policy fl_policy[TCA_FLOWER_MAX + 1] = { @@ -1325,12 +1377,16 @@ static int fl_change(struct net *net, struct sk_buff *in_skb, struct nlattr **tb; int err; - if (!tca[TCA_OPTIONS]) - return -EINVAL; + if (!tca[TCA_OPTIONS]) { + err = -EINVAL; + goto errout_fold; + } mask = kzalloc(sizeof(struct fl_flow_mask), GFP_KERNEL); - if (!mask) - return -ENOBUFS; + if (!mask) { + err = -ENOBUFS; + goto errout_fold; + } tb = kcalloc(TCA_FLOWER_MAX + 1, sizeof(struct nlattr *), GFP_KERNEL); if (!tb) { @@ -1353,6 +1409,7 @@ static int fl_change(struct net *net, struct sk_buff *in_skb, err = -ENOBUFS; goto errout_tb; } + refcount_set(&fnew->refcnt, 1); err = tcf_exts_init(&fnew->exts, net, TCA_FLOWER_ACT, 0); if (err < 0) @@ -1385,6 +1442,7 @@ static int fl_change(struct net *net, struct sk_buff *in_skb, if (!tc_in_hw(fnew->flags)) fnew->flags |= TCA_CLS_FLAGS_NOT_IN_HW; + refcount_inc(&fnew->refcnt); if (fold) { fnew->handle = handle; @@ -1403,7 +1461,11 @@ static int fl_change(struct net *net, struct sk_buff *in_skb, fl_hw_destroy_filter(tp, fold, NULL); tcf_unbind_filter(tp, &fold->res); tcf_exts_get_net(&fold->exts); - tcf_queue_work(&fold->rwork, fl_destroy_filter_work); + /* Caller holds reference to fold, so refcnt is always > 0 + * after this. + */ + refcount_dec(&fold->refcnt); + __fl_put(fold); } else { if (__fl_lookup(fnew->mask, &fnew->mkey)) { err = -EEXIST; @@ -1452,6 +1514,9 @@ static int fl_change(struct net *net, struct sk_buff *in_skb, kfree(tb); errout_mask_alloc: kfree(mask); +errout_fold: + if (fold) + __fl_put(fold); return err; } @@ -1465,24 +1530,26 @@ static int fl_delete(struct tcf_proto *tp, void *arg, bool *last, f->mask->filter_ht_params); __fl_delete(tp, f, extack); *last = list_empty(&head->masks); + __fl_put(f); + return 0; } static void fl_walk(struct tcf_proto *tp, struct tcf_walker *arg, bool rtnl_held) { - struct cls_fl_head *head = fl_head_dereference(tp); struct cls_fl_filter *f; arg->count = arg->skip; - while ((f = idr_get_next_ul(&head->handle_idr, - &arg->cookie)) != NULL) { + while ((f = fl_get_next_filter(tp, &arg->cookie)) != NULL) { if (arg->fn(tp, f, arg) < 0) { + __fl_put(f); arg->stop = 1; break; } - arg->cookie = f->handle + 1; + __fl_put(f); + arg->cookie++; arg->count++; } } @@ -2156,6 +2223,7 @@ static struct tcf_proto_ops cls_fl_ops __read_mostly = { .init = fl_init, .destroy = fl_destroy, .get = fl_get, + .put = fl_put, .change = fl_change, .delete = fl_delete, .walk = fl_walk,