diff --git a/arch/powerpc/include/asm/mmu_context.h b/arch/powerpc/include/asm/mmu_context.h
index 7d0f2d05189b4415acd76a6e9d6fbf4cc3616a3b..4d69223d217a655cb025e2a77582442be38e2bfa 100644
--- a/arch/powerpc/include/asm/mmu_context.h
+++ b/arch/powerpc/include/asm/mmu_context.h
@@ -195,6 +195,9 @@ static inline bool arch_vma_access_permitted(struct vm_area_struct *vma,
 
 #ifndef CONFIG_PPC_MEM_KEYS
 #define pkey_mm_init(mm)
+#define thread_pkey_regs_save(thread)
+#define thread_pkey_regs_restore(new_thread, old_thread)
+#define thread_pkey_regs_init(thread)
 #endif /* CONFIG_PPC_MEM_KEYS */
 
 #endif /* __KERNEL__ */
diff --git a/arch/powerpc/include/asm/pkeys.h b/arch/powerpc/include/asm/pkeys.h
index bf8e706dcb0f521e1b64ffa4c576de731727dd6e..41c303a01e1a91fb33b85a948cfba77846c0b291 100644
--- a/arch/powerpc/include/asm/pkeys.h
+++ b/arch/powerpc/include/asm/pkeys.h
@@ -139,4 +139,8 @@ static inline int arch_set_user_pkey_access(struct task_struct *tsk, int pkey,
 }
 
 extern void pkey_mm_init(struct mm_struct *mm);
+extern void thread_pkey_regs_save(struct thread_struct *thread);
+extern void thread_pkey_regs_restore(struct thread_struct *new_thread,
+				     struct thread_struct *old_thread);
+extern void thread_pkey_regs_init(struct thread_struct *thread);
 #endif /*_ASM_POWERPC_KEYS_H */
diff --git a/arch/powerpc/include/asm/processor.h b/arch/powerpc/include/asm/processor.h
index bdab3b74eb98e0eb145d904d406335daf7f0e3e6..01299cdc980676a405ab0303b469b318576037a9 100644
--- a/arch/powerpc/include/asm/processor.h
+++ b/arch/powerpc/include/asm/processor.h
@@ -309,6 +309,11 @@ struct thread_struct {
 	struct thread_vr_state ckvr_state; /* Checkpointed VR state */
 	unsigned long	ckvrsave; /* Checkpointed VRSAVE */
 #endif /* CONFIG_PPC_TRANSACTIONAL_MEM */
+#ifdef CONFIG_PPC_MEM_KEYS
+	unsigned long	amr;
+	unsigned long	iamr;
+	unsigned long	uamor;
+#endif
 #ifdef CONFIG_KVM_BOOK3S_32_HANDLER
 	void*		kvm_shadow_vcpu; /* KVM internal data */
 #endif /* CONFIG_KVM_BOOK3S_32_HANDLER */
diff --git a/arch/powerpc/kernel/process.c b/arch/powerpc/kernel/process.c
index 9eb78ec0ce5bd1d862210a95f76a9d665b41908e..755fd7e23ed4b6ba285ee7c3e053e182a17fa4f0 100644
--- a/arch/powerpc/kernel/process.c
+++ b/arch/powerpc/kernel/process.c
@@ -42,6 +42,7 @@
 #include <linux/hw_breakpoint.h>
 #include <linux/uaccess.h>
 #include <linux/elf-randomize.h>
+#include <linux/pkeys.h>
 
 #include <asm/pgtable.h>
 #include <asm/io.h>
@@ -1103,6 +1104,8 @@ static inline void save_sprs(struct thread_struct *t)
 		t->tar = mfspr(SPRN_TAR);
 	}
 #endif
+
+	thread_pkey_regs_save(t);
 }
 
 static inline void restore_sprs(struct thread_struct *old_thread,
@@ -1142,6 +1145,8 @@ static inline void restore_sprs(struct thread_struct *old_thread,
 	    old_thread->tidr != new_thread->tidr)
 		mtspr(SPRN_TIDR, new_thread->tidr);
 #endif
+
+	thread_pkey_regs_restore(new_thread, old_thread);
 }
 
 #ifdef CONFIG_PPC_BOOK3S_64
@@ -1867,6 +1872,8 @@ void start_thread(struct pt_regs *regs, unsigned long start, unsigned long sp)
 	current->thread.tm_tfiar = 0;
 	current->thread.load_tm = 0;
 #endif /* CONFIG_PPC_TRANSACTIONAL_MEM */
+
+	thread_pkey_regs_init(&current->thread);
 }
 EXPORT_SYMBOL(start_thread);
 
diff --git a/arch/powerpc/mm/pkeys.c b/arch/powerpc/mm/pkeys.c
index e61bea4a360609a57341d150ccfec37a06d861bf..91ec6c4670740c64aaa7ef1b138d19254784cfa8 100644
--- a/arch/powerpc/mm/pkeys.c
+++ b/arch/powerpc/mm/pkeys.c
@@ -12,6 +12,8 @@ DEFINE_STATIC_KEY_TRUE(pkey_disabled);
 bool pkey_execute_disable_supported;
 int  pkeys_total;		/* Total pkeys as per device tree */
 u32  initial_allocation_mask;	/* Bits set for reserved keys */
+u64  pkey_amr_uamor_mask;	/* Bits in AMR/UMOR not to be touched */
+u64  pkey_iamr_mask;		/* Bits in AMR not to be touched */
 
 #define AMR_BITS_PER_PKEY 2
 #define AMR_RD_BIT 0x1UL
@@ -70,8 +72,16 @@ int pkey_initialize(void)
 	 * programming note.
 	 */
 	initial_allocation_mask = ~0x0;
-	for (i = 2; i < (pkeys_total - os_reserved); i++)
+
+	/* register mask is in BE format */
+	pkey_amr_uamor_mask = ~0x0ul;
+	pkey_iamr_mask = ~0x0ul;
+
+	for (i = 2; i < (pkeys_total - os_reserved); i++) {
 		initial_allocation_mask &= ~(0x1 << i);
+		pkey_amr_uamor_mask &= ~(0x3ul << pkeyshift(i));
+		pkey_iamr_mask &= ~(0x1ul << pkeyshift(i));
+	}
 	return 0;
 }
 
@@ -206,3 +216,43 @@ int __arch_set_user_pkey_access(struct task_struct *tsk, int pkey,
 	init_amr(pkey, new_amr_bits);
 	return 0;
 }
+
+void thread_pkey_regs_save(struct thread_struct *thread)
+{
+	if (static_branch_likely(&pkey_disabled))
+		return;
+
+	/*
+	 * TODO: Skip saving registers if @thread hasn't used any keys yet.
+	 */
+	thread->amr = read_amr();
+	thread->iamr = read_iamr();
+	thread->uamor = read_uamor();
+}
+
+void thread_pkey_regs_restore(struct thread_struct *new_thread,
+			      struct thread_struct *old_thread)
+{
+	if (static_branch_likely(&pkey_disabled))
+		return;
+
+	/*
+	 * TODO: Just set UAMOR to zero if @new_thread hasn't used any keys yet.
+	 */
+	if (old_thread->amr != new_thread->amr)
+		write_amr(new_thread->amr);
+	if (old_thread->iamr != new_thread->iamr)
+		write_iamr(new_thread->iamr);
+	if (old_thread->uamor != new_thread->uamor)
+		write_uamor(new_thread->uamor);
+}
+
+void thread_pkey_regs_init(struct thread_struct *thread)
+{
+	if (static_branch_likely(&pkey_disabled))
+		return;
+
+	write_amr(read_amr() & pkey_amr_uamor_mask);
+	write_iamr(read_iamr() & pkey_iamr_mask);
+	write_uamor(read_uamor() & pkey_amr_uamor_mask);
+}