diff --git a/arch/arm64/include/asm/daifflags.h b/arch/arm64/include/asm/daifflags.h
index 063c964af705f0c31da0d44d10687f1e3f00ec6c..9207cd5aa39e2af6eabc04afed1a0c80351a4f51 100644
--- a/arch/arm64/include/asm/daifflags.h
+++ b/arch/arm64/include/asm/daifflags.h
@@ -9,6 +9,7 @@
 
 #include <asm/arch_gicv3.h>
 #include <asm/cpufeature.h>
+#include <asm/ptrace.h>
 
 #define DAIF_PROCCTX		0
 #define DAIF_PROCCTX_NOIRQ	PSR_I_BIT
@@ -109,4 +110,19 @@ static inline void local_daif_restore(unsigned long flags)
 		trace_hardirqs_off();
 }
 
+/*
+ * Called by synchronous exception handlers to restore the DAIF bits that were
+ * modified by taking an exception.
+ */
+static inline void local_daif_inherit(struct pt_regs *regs)
+{
+	unsigned long flags = regs->pstate & DAIF_MASK;
+
+	/*
+	 * We can't use local_daif_restore(regs->pstate) here as
+	 * system_has_prio_mask_debugging() won't restore the I bit if it can
+	 * use the pmr instead.
+	 */
+	write_sysreg(flags, daif);
+}
 #endif