diff --git a/arch/x86/mm/mpx.c b/arch/x86/mm/mpx.c index 8cc79347903968835327a3bf8a8ec2dcbb23940f..294ea2092ef538dd22b0ecd93506ab7d3d225dfb 100644 --- a/arch/x86/mm/mpx.c +++ b/arch/x86/mm/mpx.c @@ -419,6 +419,35 @@ int mpx_disable_management(void) return 0; } +static int mpx_cmpxchg_bd_entry(struct mm_struct *mm, + unsigned long *curval, + unsigned long __user *addr, + unsigned long old_val, unsigned long new_val) +{ + int ret; + /* + * user_atomic_cmpxchg_inatomic() actually uses sizeof() + * the pointer that we pass to it to figure out how much + * data to cmpxchg. We have to be careful here not to + * pass a pointer to a 64-bit data type when we only want + * a 32-bit copy. + */ + if (is_64bit_mm(mm)) { + ret = user_atomic_cmpxchg_inatomic(curval, + addr, old_val, new_val); + } else { + u32 uninitialized_var(curval_32); + u32 old_val_32 = old_val; + u32 new_val_32 = new_val; + u32 __user *addr_32 = (u32 __user *)addr; + + ret = user_atomic_cmpxchg_inatomic(&curval_32, + addr_32, old_val_32, new_val_32); + *curval = curval_32; + } + return ret; +} + /* * With 32-bit mode, MPX_BT_SIZE_BYTES is 4MB, and the size of each * bounds table is 16KB. With 64-bit mode, MPX_BT_SIZE_BYTES is 2GB, @@ -426,6 +455,7 @@ int mpx_disable_management(void) */ static int allocate_bt(long __user *bd_entry) { + struct mm_struct *mm = current->mm; unsigned long expected_old_val = 0; unsigned long actual_old_val = 0; unsigned long bt_addr; @@ -455,8 +485,8 @@ static int allocate_bt(long __user *bd_entry) * mmap_sem at this point, unlike some of the other part * of the MPX code that have to pagefault_disable(). */ - ret = user_atomic_cmpxchg_inatomic(&actual_old_val, bd_entry, - expected_old_val, bd_new_entry); + ret = mpx_cmpxchg_bd_entry(mm, &actual_old_val, bd_entry, + expected_old_val, bd_new_entry); if (ret) goto out_unmap; @@ -710,15 +740,16 @@ static int unmap_single_bt(struct mm_struct *mm, long __user *bd_entry, unsigned long bt_addr) { unsigned long expected_old_val = bt_addr | MPX_BD_ENTRY_VALID_FLAG; - unsigned long actual_old_val = 0; + unsigned long uninitialized_var(actual_old_val); int ret; while (1) { int need_write = 1; + unsigned long cleared_bd_entry = 0; pagefault_disable(); - ret = user_atomic_cmpxchg_inatomic(&actual_old_val, bd_entry, - expected_old_val, 0); + ret = mpx_cmpxchg_bd_entry(mm, &actual_old_val, + bd_entry, expected_old_val, cleared_bd_entry); pagefault_enable(); if (!ret) break;