diff --git a/include/linux/netlink.h b/include/linux/netlink.h index c256ebe2a7b448ffae662ee78dab5acc62ace0e1..f8f3d1c927f8b1ebd8de989ae5cf752f503356d8 100644 --- a/include/linux/netlink.h +++ b/include/linux/netlink.h @@ -151,6 +151,7 @@ struct netlink_skb_parms extern struct sock *netlink_kernel_create(int unit, unsigned int groups, void (*input)(struct sock *sk, int len), struct module *module); extern void netlink_ack(struct sk_buff *in_skb, struct nlmsghdr *nlh, int err); +extern int netlink_has_listeners(struct sock *sk, unsigned int group); extern int netlink_unicast(struct sock *ssk, struct sk_buff *skb, __u32 pid, int nonblock); extern int netlink_broadcast(struct sock *ssk, struct sk_buff *skb, __u32 pid, __u32 group, gfp_t allocation); diff --git a/net/netlink/af_netlink.c b/net/netlink/af_netlink.c index 59dc7d140600d95ec10a495beaed0f37f4f390e5..d00a9034cb5f9177aa31ad4d533ce9490cfa77c6 100644 --- a/net/netlink/af_netlink.c +++ b/net/netlink/af_netlink.c @@ -106,6 +106,7 @@ struct nl_pid_hash { struct netlink_table { struct nl_pid_hash hash; struct hlist_head mc_list; + unsigned long *listeners; unsigned int nl_nonroot; unsigned int groups; struct module *module; @@ -296,6 +297,24 @@ static inline int nl_pid_hash_dilute(struct nl_pid_hash *hash, int len) static const struct proto_ops netlink_ops; +static void +netlink_update_listeners(struct sock *sk) +{ + struct netlink_table *tbl = &nl_table[sk->sk_protocol]; + struct hlist_node *node; + unsigned long mask; + unsigned int i; + + for (i = 0; i < NLGRPSZ(tbl->groups)/sizeof(unsigned long); i++) { + mask = 0; + sk_for_each_bound(sk, node, &tbl->mc_list) + mask |= nlk_sk(sk)->groups[i]; + tbl->listeners[i] = mask; + } + /* this function is only called with the netlink table "grabbed", which + * makes sure updates are visible before bind or setsockopt return. */ +} + static int netlink_insert(struct sock *sk, u32 pid) { struct nl_pid_hash *hash = &nl_table[sk->sk_protocol].hash; @@ -456,12 +475,14 @@ static int netlink_release(struct socket *sock) if (nlk->module) module_put(nlk->module); + netlink_table_grab(); if (nlk->flags & NETLINK_KERNEL_SOCKET) { - netlink_table_grab(); + kfree(nl_table[sk->sk_protocol].listeners); nl_table[sk->sk_protocol].module = NULL; nl_table[sk->sk_protocol].registered = 0; - netlink_table_ungrab(); - } + } else if (nlk->subscriptions) + netlink_update_listeners(sk); + netlink_table_ungrab(); kfree(nlk->groups); nlk->groups = NULL; @@ -589,6 +610,7 @@ static int netlink_bind(struct socket *sock, struct sockaddr *addr, int addr_len hweight32(nladdr->nl_groups) - hweight32(nlk->groups[0])); nlk->groups[0] = (nlk->groups[0] & ~0xffffffffUL) | nladdr->nl_groups; + netlink_update_listeners(sk); netlink_table_ungrab(); return 0; @@ -807,6 +829,17 @@ int netlink_unicast(struct sock *ssk, struct sk_buff *skb, u32 pid, int nonblock return netlink_sendskb(sk, skb, ssk->sk_protocol); } +int netlink_has_listeners(struct sock *sk, unsigned int group) +{ + int res = 0; + + BUG_ON(!(nlk_sk(sk)->flags & NETLINK_KERNEL_SOCKET)); + if (group - 1 < nl_table[sk->sk_protocol].groups) + res = test_bit(group - 1, nl_table[sk->sk_protocol].listeners); + return res; +} +EXPORT_SYMBOL_GPL(netlink_has_listeners); + static __inline__ int netlink_broadcast_deliver(struct sock *sk, struct sk_buff *skb) { struct netlink_sock *nlk = nlk_sk(sk); @@ -1011,6 +1044,7 @@ static int netlink_setsockopt(struct socket *sock, int level, int optname, else __clear_bit(val - 1, nlk->groups); netlink_update_subscriptions(sk, subscriptions); + netlink_update_listeners(sk); netlink_table_ungrab(); err = 0; break; @@ -1237,6 +1271,7 @@ netlink_kernel_create(int unit, unsigned int groups, struct socket *sock; struct sock *sk; struct netlink_sock *nlk; + unsigned long *listeners = NULL; if (!nl_table) return NULL; @@ -1250,6 +1285,13 @@ netlink_kernel_create(int unit, unsigned int groups, if (__netlink_create(sock, unit) < 0) goto out_sock_release; + if (groups < 32) + groups = 32; + + listeners = kzalloc(NLGRPSZ(groups), GFP_KERNEL); + if (!listeners) + goto out_sock_release; + sk = sock->sk; sk->sk_data_ready = netlink_data_ready; if (input) @@ -1262,7 +1304,8 @@ netlink_kernel_create(int unit, unsigned int groups, nlk->flags |= NETLINK_KERNEL_SOCKET; netlink_table_grab(); - nl_table[unit].groups = groups < 32 ? 32 : groups; + nl_table[unit].groups = groups; + nl_table[unit].listeners = listeners; nl_table[unit].module = module; nl_table[unit].registered = 1; netlink_table_ungrab(); @@ -1270,6 +1313,7 @@ netlink_kernel_create(int unit, unsigned int groups, return sk; out_sock_release: + kfree(listeners); sock_release(sock); return NULL; }