diff --git a/include/linux/bpfilter.h b/include/linux/bpfilter.h index 70ffeed280e9dc02d0ea1ef3766a41f937b373f6..8ebcbdd70bdc55f8c9b93aa643d06145f0a453ed 100644 --- a/include/linux/bpfilter.h +++ b/include/linux/bpfilter.h @@ -15,6 +15,8 @@ struct bpfilter_umh_ops { int (*sockopt)(struct sock *sk, int optname, char __user *optval, unsigned int optlen, bool is_set); + int (*start)(void); + bool stop; }; extern struct bpfilter_umh_ops bpfilter_ops; #endif diff --git a/net/bpfilter/bpfilter_kern.c b/net/bpfilter/bpfilter_kern.c index a68940b74c0154a76f28c6d99aeaf29ad68f0fe9..c0fcde910a7ad75277765aaf7d7476cdc7dc861e 100644 --- a/net/bpfilter/bpfilter_kern.c +++ b/net/bpfilter/bpfilter_kern.c @@ -16,13 +16,14 @@ extern char bpfilter_umh_end; /* since ip_getsockopt() can run in parallel, serialize access to umh */ static DEFINE_MUTEX(bpfilter_lock); -static void shutdown_umh(struct umh_info *info) +static void shutdown_umh(void) { struct task_struct *tsk; - if (!info->pid) + if (bpfilter_ops.stop) return; - tsk = get_pid_task(find_vpid(info->pid), PIDTYPE_PID); + + tsk = get_pid_task(find_vpid(bpfilter_ops.info.pid), PIDTYPE_PID); if (tsk) { force_sig(SIGKILL, tsk); put_task_struct(tsk); @@ -31,10 +32,8 @@ static void shutdown_umh(struct umh_info *info) static void __stop_umh(void) { - if (IS_ENABLED(CONFIG_INET)) { - bpfilter_ops.sockopt = NULL; - shutdown_umh(&bpfilter_ops.info); - } + if (IS_ENABLED(CONFIG_INET)) + shutdown_umh(); } static void stop_umh(void) @@ -85,7 +84,7 @@ static int __bpfilter_process_sockopt(struct sock *sk, int optname, return ret; } -static int __init load_umh(void) +static int start_umh(void) { int err; @@ -95,6 +94,7 @@ static int __init load_umh(void) &bpfilter_ops.info); if (err) return err; + bpfilter_ops.stop = false; pr_info("Loaded bpfilter_umh pid %d\n", bpfilter_ops.info.pid); /* health check that usermode process started correctly */ @@ -102,14 +102,31 @@ static int __init load_umh(void) stop_umh(); return -EFAULT; } - if (IS_ENABLED(CONFIG_INET)) - bpfilter_ops.sockopt = &__bpfilter_process_sockopt; return 0; } +static int __init load_umh(void) +{ + int err; + + if (!bpfilter_ops.stop) + return -EFAULT; + err = start_umh(); + if (!err && IS_ENABLED(CONFIG_INET)) { + bpfilter_ops.sockopt = &__bpfilter_process_sockopt; + bpfilter_ops.start = &start_umh; + } + + return err; +} + static void __exit fini_umh(void) { + if (IS_ENABLED(CONFIG_INET)) { + bpfilter_ops.start = NULL; + bpfilter_ops.sockopt = NULL; + } stop_umh(); } module_init(load_umh); diff --git a/net/bpfilter/bpfilter_umh_blob.S b/net/bpfilter/bpfilter_umh_blob.S index 40311d10d2f270adbb7a51cfc3bad0c2834ca438..7f1c521dcc2f75fadbc9c6241c2c60c7a1d6b24b 100644 --- a/net/bpfilter/bpfilter_umh_blob.S +++ b/net/bpfilter/bpfilter_umh_blob.S @@ -1,5 +1,5 @@ /* SPDX-License-Identifier: GPL-2.0 */ - .section .init.rodata, "a" + .section .bpfilter_umh, "a" .global bpfilter_umh_start bpfilter_umh_start: .incbin "net/bpfilter/bpfilter_umh" diff --git a/net/ipv4/bpfilter/sockopt.c b/net/ipv4/bpfilter/sockopt.c index c326cfbc0f62198afae382f71f0b4a480e97aa84..de84ede4e76537ec84fb1520423f22587aed391a 100644 --- a/net/ipv4/bpfilter/sockopt.c +++ b/net/ipv4/bpfilter/sockopt.c @@ -14,6 +14,7 @@ EXPORT_SYMBOL_GPL(bpfilter_ops); static void bpfilter_umh_cleanup(struct umh_info *info) { + bpfilter_ops.stop = true; fput(info->pipe_to_umh); fput(info->pipe_from_umh); info->pid = 0; @@ -23,14 +24,21 @@ static int bpfilter_mbox_request(struct sock *sk, int optname, char __user *optval, unsigned int optlen, bool is_set) { + int err; + if (!bpfilter_ops.sockopt) { - int err = request_module("bpfilter"); + err = request_module("bpfilter"); if (err) return err; if (!bpfilter_ops.sockopt) return -ECHILD; } + if (bpfilter_ops.stop) { + err = bpfilter_ops.start(); + if (err) + return err; + } return bpfilter_ops.sockopt(sk, optname, optval, optlen, is_set); } @@ -53,6 +61,7 @@ int bpfilter_ip_get_sockopt(struct sock *sk, int optname, char __user *optval, static int __init bpfilter_sockopt_init(void) { + bpfilter_ops.stop = true; bpfilter_ops.info.cmdline = "bpfilter_umh"; bpfilter_ops.info.cleanup = &bpfilter_umh_cleanup;