diff --git a/arch/arm64/include/asm/fpsimd.h b/arch/arm64/include/asm/fpsimd.h index 9bbd74c0ea6be964aaf409fc892e74c8da9c74da..86f550ce7b4dcd3d3df5665eed3218a10bede6a8 100644 --- a/arch/arm64/include/asm/fpsimd.h +++ b/arch/arm64/include/asm/fpsimd.h @@ -20,6 +20,7 @@ #ifndef __ASSEMBLY__ +#include #include /* @@ -70,17 +71,24 @@ extern void fpsimd_update_current_state(struct fpsimd_state *state); extern void fpsimd_flush_task_state(struct task_struct *target); +/* Maximum VL that SVE VL-agnostic software can transparently support */ +#define SVE_VL_ARCH_MAX 0x100 + extern void sve_save_state(void *state, u32 *pfpsr); extern void sve_load_state(void const *state, u32 const *pfpsr, unsigned long vq_minus_1); extern unsigned int sve_get_vl(void); +extern int __ro_after_init sve_max_vl; + #ifdef CONFIG_ARM64_SVE extern size_t sve_state_size(struct task_struct const *task); extern void sve_alloc(struct task_struct *task); extern void fpsimd_release_task(struct task_struct *task); +extern int sve_set_vector_length(struct task_struct *task, + unsigned long vl, unsigned long flags); #else /* ! CONFIG_ARM64_SVE */ diff --git a/arch/arm64/kernel/fpsimd.c b/arch/arm64/kernel/fpsimd.c index e7733fb19388db6fa00de2b8bc7859d5be0b9752..667be3472114285b285ab33c970ebccbc5e33b77 100644 --- a/arch/arm64/kernel/fpsimd.c +++ b/arch/arm64/kernel/fpsimd.c @@ -17,8 +17,10 @@ * along with this program. If not, see . */ +#include #include #include +#include #include #include #include @@ -28,6 +30,7 @@ #include #include #include +#include #include #include #include @@ -114,6 +117,20 @@ static DEFINE_PER_CPU(struct fpsimd_state *, fpsimd_last_state); /* Default VL for tasks that don't set it explicitly: */ static int sve_default_vl = SVE_VL_MIN; +#ifdef CONFIG_ARM64_SVE + +/* Maximum supported vector length across all CPUs (initially poisoned) */ +int __ro_after_init sve_max_vl = -1; +/* Set of available vector lengths, as vq_to_bit(vq): */ +static DECLARE_BITMAP(sve_vq_map, SVE_VQ_MAX); + +#else /* ! CONFIG_ARM64_SVE */ + +/* Dummy declaration for code that will be optimised out: */ +extern DECLARE_BITMAP(sve_vq_map, SVE_VQ_MAX); + +#endif /* ! CONFIG_ARM64_SVE */ + /* * Call __sve_free() directly only if you know task can't be scheduled * or preempted. @@ -271,6 +288,50 @@ static void task_fpsimd_save(void) } } +/* + * Helpers to translate bit indices in sve_vq_map to VQ values (and + * vice versa). This allows find_next_bit() to be used to find the + * _maximum_ VQ not exceeding a certain value. + */ + +static unsigned int vq_to_bit(unsigned int vq) +{ + return SVE_VQ_MAX - vq; +} + +static unsigned int bit_to_vq(unsigned int bit) +{ + if (WARN_ON(bit >= SVE_VQ_MAX)) + bit = SVE_VQ_MAX - 1; + + return SVE_VQ_MAX - bit; +} + +/* + * All vector length selection from userspace comes through here. + * We're on a slow path, so some sanity-checks are included. + * If things go wrong there's a bug somewhere, but try to fall back to a + * safe choice. + */ +static unsigned int find_supported_vector_length(unsigned int vl) +{ + int bit; + int max_vl = sve_max_vl; + + if (WARN_ON(!sve_vl_valid(vl))) + vl = SVE_VL_MIN; + + if (WARN_ON(!sve_vl_valid(max_vl))) + max_vl = SVE_VL_MIN; + + if (vl > max_vl) + vl = max_vl; + + bit = find_next_bit(sve_vq_map, SVE_VQ_MAX, + vq_to_bit(sve_vq_from_vl(vl))); + return sve_vl_from_vq(bit_to_vq(bit)); +} + #define ZREG(sve_state, vq, n) ((char *)(sve_state) + \ (SVE_SIG_ZREG_OFFSET(vq, n) - SVE_SIG_REGS_OFFSET)) @@ -365,6 +426,76 @@ void sve_alloc(struct task_struct *task) BUG_ON(!task->thread.sve_state); } +int sve_set_vector_length(struct task_struct *task, + unsigned long vl, unsigned long flags) +{ + if (flags & ~(unsigned long)(PR_SVE_VL_INHERIT | + PR_SVE_SET_VL_ONEXEC)) + return -EINVAL; + + if (!sve_vl_valid(vl)) + return -EINVAL; + + /* + * Clamp to the maximum vector length that VL-agnostic SVE code can + * work with. A flag may be assigned in the future to allow setting + * of larger vector lengths without confusing older software. + */ + if (vl > SVE_VL_ARCH_MAX) + vl = SVE_VL_ARCH_MAX; + + vl = find_supported_vector_length(vl); + + if (flags & (PR_SVE_VL_INHERIT | + PR_SVE_SET_VL_ONEXEC)) + task->thread.sve_vl_onexec = vl; + else + /* Reset VL to system default on next exec: */ + task->thread.sve_vl_onexec = 0; + + /* Only actually set the VL if not deferred: */ + if (flags & PR_SVE_SET_VL_ONEXEC) + goto out; + + if (vl == task->thread.sve_vl) + goto out; + + /* + * To ensure the FPSIMD bits of the SVE vector registers are preserved, + * write any live register state back to task_struct, and convert to a + * non-SVE thread. + */ + if (task == current) { + local_bh_disable(); + + task_fpsimd_save(); + set_thread_flag(TIF_FOREIGN_FPSTATE); + } + + fpsimd_flush_task_state(task); + if (test_and_clear_tsk_thread_flag(task, TIF_SVE)) + sve_to_fpsimd(task); + + if (task == current) + local_bh_enable(); + + /* + * Force reallocation of task SVE state to the correct size + * on next use: + */ + sve_free(task); + + task->thread.sve_vl = vl; + +out: + if (flags & PR_SVE_VL_INHERIT) + set_tsk_thread_flag(task, TIF_SVE_VL_INHERIT); + else + clear_tsk_thread_flag(task, TIF_SVE_VL_INHERIT); + + return 0; +} + /* * Called from the put_task_struct() path, which cannot get here * unless dead_task is really dead and not schedulable. @@ -481,7 +612,7 @@ void fpsimd_thread_switch(struct task_struct *next) void fpsimd_flush_thread(void) { - int vl; + int vl, supported_vl; if (!system_supports_fpsimd()) return; @@ -509,6 +640,10 @@ void fpsimd_flush_thread(void) if (WARN_ON(!sve_vl_valid(vl))) vl = SVE_VL_MIN; + supported_vl = find_supported_vector_length(vl); + if (WARN_ON(supported_vl != vl)) + vl = supported_vl; + current->thread.sve_vl = vl; /* diff --git a/include/uapi/linux/prctl.h b/include/uapi/linux/prctl.h index a8d0759a9e400c5d472fe37d13a924bc9e9777a6..1b64901ca6b3429631eaa576613819ee26814b45 100644 --- a/include/uapi/linux/prctl.h +++ b/include/uapi/linux/prctl.h @@ -197,4 +197,9 @@ struct prctl_mm_map { # define PR_CAP_AMBIENT_LOWER 3 # define PR_CAP_AMBIENT_CLEAR_ALL 4 +/* arm64 Scalable Vector Extension controls */ +# define PR_SVE_SET_VL_ONEXEC (1 << 18) /* defer effect until exec */ +# define PR_SVE_VL_LEN_MASK 0xffff +# define PR_SVE_VL_INHERIT (1 << 17) /* inherit across exec */ + #endif /* _LINUX_PRCTL_H */