diff --git a/include/net/sock.h b/include/net/sock.h index 74dbbf0e30c03b0087ce342c067909932a362385..117f1126e45abf2c79839d2f897f3813a478804f 100644 --- a/include/net/sock.h +++ b/include/net/sock.h @@ -1808,7 +1808,12 @@ void sk_common_release(struct sock *sk); * Default socket callbacks and setup code */ -/* Initialise core socket variables */ +/* Initialise core socket variables using an explicit uid. */ +void sock_init_data_uid(struct socket *sock, struct sock *sk, kuid_t uid); + +/* Initialise core socket variables. + * Assumes struct socket *sock is embedded in a struct socket_alloc. + */ void sock_init_data(struct socket *sock, struct sock *sk); /* diff --git a/net/core/sock.c b/net/core/sock.c index 56a927b9b372a9b43e2ae338d996a40b323344aa..d8d42ff15d209bd2f351d9539abdf9bf13cef4d8 100644 --- a/net/core/sock.c +++ b/net/core/sock.c @@ -2967,7 +2967,7 @@ void sk_stop_timer_sync(struct sock *sk, struct timer_list *timer) } EXPORT_SYMBOL(sk_stop_timer_sync); -void sock_init_data(struct socket *sock, struct sock *sk) +void sock_init_data_uid(struct socket *sock, struct sock *sk, kuid_t uid) { sk_init_common(sk); sk->sk_send_head = NULL; @@ -2986,13 +2986,12 @@ void sock_init_data(struct socket *sock, struct sock *sk) sk->sk_type = sock->type; RCU_INIT_POINTER(sk->sk_wq, &sock->wq); sock->sk = sk; - sk->sk_uid = SOCK_INODE(sock)->i_uid; sk->sk_gid = SOCK_INODE(sock)->i_gid; } else { RCU_INIT_POINTER(sk->sk_wq, NULL); - sk->sk_uid = make_kuid(sock_net(sk)->user_ns, 0); sk->sk_gid = make_kgid(sock_net(sk)->user_ns, 0); } + sk->sk_uid = uid; rwlock_init(&sk->sk_callback_lock); if (sk->sk_kern_sock) @@ -3050,6 +3049,16 @@ void sock_init_data(struct socket *sock, struct sock *sk) refcount_set(&sk->sk_refcnt, 1); atomic_set(&sk->sk_drops, 0); } +EXPORT_SYMBOL(sock_init_data_uid); + +void sock_init_data(struct socket *sock, struct sock *sk) +{ + kuid_t uid = sock ? + SOCK_INODE(sock)->i_uid : + make_kuid(sock_net(sk)->user_ns, 0); + + sock_init_data_uid(sock, sk, uid); +} EXPORT_SYMBOL(sock_init_data); void lock_sock_nested(struct sock *sk, int subclass)