diff --git a/net/netlink/af_netlink.c b/net/netlink/af_netlink.c index 7b7b45a195979dc46efee171bc7a322275023ecc..c41a88100fea6f88f7b6ea059ffda39cd4033866 100644 --- a/net/netlink/af_netlink.c +++ b/net/netlink/af_netlink.c @@ -73,8 +73,12 @@ struct netlink_sock { struct netlink_callback *cb; spinlock_t cb_lock; void (*data_ready)(struct sock *sk, int bytes); + struct module *module; + u32 flags; }; +#define NETLINK_KERNEL_SOCKET 0x1 + static inline struct netlink_sock *nlk_sk(struct sock *sk) { return (struct netlink_sock *)sk; @@ -97,7 +101,7 @@ struct netlink_table { struct nl_pid_hash hash; struct hlist_head mc_list; unsigned int nl_nonroot; - struct proto_ops *p_ops; + struct module *module; }; static struct netlink_table *nl_table; @@ -338,6 +342,7 @@ static int netlink_create(struct socket *sock, int protocol) { struct sock *sk; struct netlink_sock *nlk; + struct module *module; sock->state = SS_UNCONNECTED; @@ -347,30 +352,36 @@ static int netlink_create(struct socket *sock, int protocol) if (protocol<0 || protocol >= MAX_LINKS) return -EPROTONOSUPPORT; - netlink_table_grab(); + netlink_lock_table(); if (!nl_table[protocol].hash.entries) { #ifdef CONFIG_KMOD /* We do 'best effort'. If we find a matching module, * it is loaded. If not, we don't return an error to * allow pure userspace<->userspace communication. -HW */ - netlink_table_ungrab(); + netlink_unlock_table(); request_module("net-pf-%d-proto-%d", PF_NETLINK, protocol); - netlink_table_grab(); + netlink_lock_table(); #endif } - netlink_table_ungrab(); + module = nl_table[protocol].module; + if (!try_module_get(module)) + module = NULL; + netlink_unlock_table(); - sock->ops = nl_table[protocol].p_ops; + sock->ops = &netlink_ops; sk = sk_alloc(PF_NETLINK, GFP_KERNEL, &netlink_proto, 1); - if (!sk) + if (!sk) { + module_put(module); return -ENOMEM; + } sock_init_data(sock, sk); nlk = nlk_sk(sk); + nlk->module = module; spin_lock_init(&nlk->cb_lock); init_waitqueue_head(&nlk->wait); sk->sk_destruct = netlink_sock_destruct; @@ -415,22 +426,15 @@ static int netlink_release(struct socket *sock) notifier_call_chain(&netlink_chain, NETLINK_URELEASE, &n); } - /* When this is a kernel socket, we need to remove the owner pointer, - * since we don't know whether the module will be dying at any given - * point - HW - */ - if (!nlk->pid) { - struct proto_ops *p_tmp; + if (nlk->module) + module_put(nlk->module); + if (nlk->flags & NETLINK_KERNEL_SOCKET) { netlink_table_grab(); - p_tmp = nl_table[sk->sk_protocol].p_ops; - if (p_tmp != &netlink_ops) { - nl_table[sk->sk_protocol].p_ops = &netlink_ops; - kfree(p_tmp); - } + nl_table[sk->sk_protocol].module = NULL; netlink_table_ungrab(); } - + sock_put(sk); return 0; } @@ -1060,9 +1064,9 @@ static void netlink_data_ready(struct sock *sk, int len) struct sock * netlink_kernel_create(int unit, void (*input)(struct sock *sk, int len), struct module *module) { - struct proto_ops *p_ops; struct socket *sock; struct sock *sk; + struct netlink_sock *nlk; if (!nl_table) return NULL; @@ -1070,64 +1074,32 @@ netlink_kernel_create(int unit, void (*input)(struct sock *sk, int len), struct if (unit<0 || unit>=MAX_LINKS) return NULL; - /* Do a quick check, to make us not go down to netlink_insert() - * if protocol already has kernel socket. - */ - sk = netlink_lookup(unit, 0); - if (unlikely(sk)) { - sock_put(sk); - return NULL; - } - if (sock_create_lite(PF_NETLINK, SOCK_DGRAM, unit, &sock)) return NULL; - sk = NULL; - if (module) { - /* Every registering protocol implemented in a module needs - * it's own p_ops, since the socket code cannot deal with - * module refcounting otherwise. -HW - */ - p_ops = kmalloc(sizeof(*p_ops), GFP_KERNEL); - if (!p_ops) - goto out_sock_release; - - memcpy(p_ops, &netlink_ops, sizeof(*p_ops)); - p_ops->owner = module; - } else - p_ops = &netlink_ops; - - netlink_table_grab(); - nl_table[unit].p_ops = p_ops; - netlink_table_ungrab(); - - if (netlink_create(sock, unit) < 0) { - sk = NULL; - goto out_kfree_p_ops; - } + if (netlink_create(sock, unit) < 0) + goto out_sock_release; sk = sock->sk; sk->sk_data_ready = netlink_data_ready; if (input) nlk_sk(sk)->data_ready = input; - if (netlink_insert(sk, 0)) { - sk = NULL; - goto out_kfree_p_ops; - } + if (netlink_insert(sk, 0)) + goto out_sock_release; - return sk; + nlk = nlk_sk(sk); + nlk->flags |= NETLINK_KERNEL_SOCKET; -out_kfree_p_ops: netlink_table_grab(); - if (nl_table[unit].p_ops != &netlink_ops) { - kfree(nl_table[unit].p_ops); - nl_table[unit].p_ops = &netlink_ops; - } + nl_table[unit].module = module; netlink_table_ungrab(); + + return sk; + out_sock_release: sock_release(sock); - return sk; + return NULL; } void netlink_set_nonroot(int protocol, unsigned int flags) @@ -1490,8 +1462,6 @@ static int __init netlink_proto_init(void) for (i = 0; i < MAX_LINKS; i++) { struct nl_pid_hash *hash = &nl_table[i].hash; - nl_table[i].p_ops = &netlink_ops; - hash->table = nl_pid_hash_alloc(1 * sizeof(*hash->table)); if (!hash->table) { while (i-- > 0)