diff --git a/arch/Kconfig b/arch/Kconfig
index 0a0ea5066fc0cb413f2940edefa7cb289739c709..6801123932a503ba64bcf1c9dfbb7877fff0f094 100644
--- a/arch/Kconfig
+++ b/arch/Kconfig
@@ -362,6 +362,9 @@ config HAVE_ARCH_JUMP_LABEL
 config HAVE_RCU_TABLE_FREE
 	bool
 
+config HAVE_RCU_TABLE_INVALIDATE
+	bool
+
 config ARCH_HAVE_NMI_SAFE_CMPXCHG
 	bool
 
diff --git a/arch/x86/Kconfig b/arch/x86/Kconfig
index 512003f168895c231150f943d18db545e82e9f61..c5ff296bc5d1252f6eaaf847755a3c88fed52c89 100644
--- a/arch/x86/Kconfig
+++ b/arch/x86/Kconfig
@@ -180,7 +180,8 @@ config X86
 	select HAVE_HARDLOCKUP_DETECTOR_PERF	if PERF_EVENTS && HAVE_PERF_EVENTS_NMI
 	select HAVE_PERF_REGS
 	select HAVE_PERF_USER_STACK_DUMP
-	select HAVE_RCU_TABLE_FREE
+	select HAVE_RCU_TABLE_FREE		if PARAVIRT
+	select HAVE_RCU_TABLE_INVALIDATE	if HAVE_RCU_TABLE_FREE
 	select HAVE_REGS_AND_STACK_ACCESS_API
 	select HAVE_RELIABLE_STACKTRACE		if X86_64 && (UNWINDER_FRAME_POINTER || UNWINDER_ORC) && STACK_VALIDATION
 	select HAVE_STACKPROTECTOR		if CC_HAS_SANE_STACKPROTECTOR
diff --git a/arch/x86/hyperv/mmu.c b/arch/x86/hyperv/mmu.c
index 1147e1fed7fff45e28e9c18e59e81542d1647d7f..ef5f29f913d7b064f1a086ac0674bebb4c8a845a 100644
--- a/arch/x86/hyperv/mmu.c
+++ b/arch/x86/hyperv/mmu.c
@@ -9,6 +9,7 @@
 #include <asm/mshyperv.h>
 #include <asm/msr.h>
 #include <asm/tlbflush.h>
+#include <asm/tlb.h>
 
 #define CREATE_TRACE_POINTS
 #include <asm/trace/hyperv.h>
@@ -231,4 +232,5 @@ void hyperv_setup_mmu_ops(void)
 
 	pr_info("Using hypercall for remote TLB flush\n");
 	pv_mmu_ops.flush_tlb_others = hyperv_flush_tlb_others;
+	pv_mmu_ops.tlb_remove_table = tlb_remove_table;
 }
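
[ Note: Hyper-V (like Xen and KVM below) must opt in to tlb_remove_table()
  because hyperv_flush_tlb_others() flushes remote TLBs with a hypercall
  instead of IPIs.  Lockless software walkers in the get_user_pages_fast()
  style rely on disabled IRQs to block the synchronizing IPI that a native
  flush or tlb_remove_table_one() would send; with no IPI arriving, freed
  page tables must instead be held back by an RCU-sched grace period.  A
  sketch of the walker-side assumption, purely for illustration:

	static int walker_sketch(struct mm_struct *mm, unsigned long addr)
	{
		unsigned long flags;

		local_irq_save(flags);
		/*
		 * ... walk mm's page tables locklessly ...
		 *
		 * IRQs off blocks the tlb_remove_table_smp_sync() IPI
		 * and also holds off an RCU-sched grace period, so the
		 * tables cannot be freed under us on either path.
		 */
		local_irq_restore(flags);
		return 0;
	}
]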
diff --git a/arch/x86/include/asm/paravirt.h b/arch/x86/include/asm/paravirt.h
index d49bbf4bb5c867ef5a086a15b02882d694a961f6..e375d4266b53e35bbe2383d7633d478f29871c4f 100644
--- a/arch/x86/include/asm/paravirt.h
+++ b/arch/x86/include/asm/paravirt.h
@@ -309,6 +309,11 @@ static inline void flush_tlb_others(const struct cpumask *cpumask,
 	PVOP_VCALL2(pv_mmu_ops.flush_tlb_others, cpumask, info);
 }
 
+static inline void paravirt_tlb_remove_table(struct mmu_gather *tlb, void *table)
+{
+	PVOP_VCALL2(pv_mmu_ops.tlb_remove_table, tlb, table);
+}
+
 static inline int paravirt_pgd_alloc(struct mm_struct *mm)
 {
 	return PVOP_CALL1(int, pv_mmu_ops.pgd_alloc, mm);
diff --git a/arch/x86/include/asm/paravirt_types.h b/arch/x86/include/asm/paravirt_types.h
index 180bc0bff0fbd98195b8e81ce9aa7232dde06ee2..4b75acc23b30bad579578b6988140f55bc88f6b9 100644
--- a/arch/x86/include/asm/paravirt_types.h
+++ b/arch/x86/include/asm/paravirt_types.h
@@ -54,6 +54,7 @@ struct desc_struct;
 struct task_struct;
 struct cpumask;
 struct flush_tlb_info;
+struct mmu_gather;
 
 /*
  * Wrapper type for pointers to code which uses the non-standard
@@ -222,6 +223,8 @@ struct pv_mmu_ops {
 	void (*flush_tlb_others)(const struct cpumask *cpus,
 				 const struct flush_tlb_info *info);
 
+	void (*tlb_remove_table)(struct mmu_gather *tlb, void *table);
+
 	/* Hooks for allocating and freeing a pagetable top-level */
 	int  (*pgd_alloc)(struct mm_struct *mm);
 	void (*pgd_free)(struct mm_struct *mm, pgd_t *pgd);
diff --git a/arch/x86/include/asm/tlbflush.h b/arch/x86/include/asm/tlbflush.h
index 511bf5fae8b82abb7ebbd7f4173f70afae9d1538..29c9da6c62fc16b8b28bec203eb8982b2f570db5 100644
--- a/arch/x86/include/asm/tlbflush.h
+++ b/arch/x86/include/asm/tlbflush.h
@@ -148,6 +148,22 @@ static inline unsigned long build_cr3_noflush(pgd_t *pgd, u16 asid)
 #define __flush_tlb_one_user(addr) __native_flush_tlb_one_user(addr)
 #endif
 
+static inline bool tlb_defer_switch_to_init_mm(void)
+{
+	/*
+	 * If we have PCID, then switching to init_mm is reasonably
+	 * fast.  If we don't have PCID, then switching to init_mm is
+	 * quite slow, so we try to defer it in the hopes that we can
+	 * avoid it entirely.  The latter approach runs the risk of
+	 * receiving otherwise unnecessary IPIs.
+	 *
+	 * This choice is just a heuristic.  The tlb code can handle this
+	 * function returning true or false regardless of whether we have
+	 * PCID.
+	 */
+	return !static_cpu_has(X86_FEATURE_PCID);
+}
+
 struct tlb_context {
 	u64 ctx_id;
 	u64 tlb_gen;
@@ -536,11 +552,9 @@ extern void arch_tlbbatch_flush(struct arch_tlbflush_unmap_batch *batch);
 #ifndef CONFIG_PARAVIRT
 #define flush_tlb_others(mask, info)	\
 	native_flush_tlb_others(mask, info)
-#endif
 
-extern void tlb_flush_remove_tables(struct mm_struct *mm);
-extern void tlb_flush_remove_tables_local(void *arg);
-
-#define HAVE_TLB_FLUSH_REMOVE_TABLES
+#define paravirt_tlb_remove_table(tlb, page) \
+	tlb_remove_page(tlb, (void *)(page))
+#endif
 
 #endif /* _ASM_X86_TLBFLUSH_H */
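
[ Note: the net effect of the #ifndef CONFIG_PARAVIRT block above is that
  non-paravirt kernels bypass the RCU machinery entirely.  Roughly:

	/*
	 * CONFIG_PARAVIRT=n:
	 *	paravirt_tlb_remove_table(tlb, pte)
	 *	  -> tlb_remove_page(tlb, pte)
	 *	     (batched; pages freed after the native flush IPIs)
	 *
	 * CONFIG_PARAVIRT=y:
	 *	paravirt_tlb_remove_table(tlb, pte)
	 *	  -> pv_mmu_ops.tlb_remove_table(tlb, pte)
	 *	     (tlb_remove_page on bare metal; tlb_remove_table,
	 *	      i.e. RCU-deferred freeing, under Xen/KVM/Hyper-V)
	 */

  tlb_defer_switch_to_init_mm() is restored from the pre-lazy-TLB-patches
  code: without PCID a CR3 write is expensive, so staying lazy and eating
  the occasional IPI wins; with PCID the switch to init_mm is cheap enough
  to do eagerly. ]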
diff --git a/arch/x86/kernel/kvm.c b/arch/x86/kernel/kvm.c
index 0f471bd934172e673770936e9903d3c81623a6cb..d9b71924c23c9b939986b2d2ab153ba15b7077c3 100644
--- a/arch/x86/kernel/kvm.c
+++ b/arch/x86/kernel/kvm.c
@@ -45,6 +45,7 @@
 #include <asm/apic.h>
 #include <asm/apicdef.h>
 #include <asm/hypervisor.h>
+#include <asm/tlb.h>
 
 static int kvmapf = 1;
 
@@ -636,8 +637,10 @@ static void __init kvm_guest_init(void)
 
 	if (kvm_para_has_feature(KVM_FEATURE_PV_TLB_FLUSH) &&
 	    !kvm_para_has_hint(KVM_HINTS_REALTIME) &&
-	    kvm_para_has_feature(KVM_FEATURE_STEAL_TIME))
+	    kvm_para_has_feature(KVM_FEATURE_STEAL_TIME)) {
 		pv_mmu_ops.flush_tlb_others = kvm_flush_tlb_others;
+		pv_mmu_ops.tlb_remove_table = tlb_remove_table;
+	}
 
 	if (kvm_para_has_feature(KVM_FEATURE_PV_EOI))
 		apic_set_eoi_write(kvm_guest_apic_eoi_write);
diff --git a/arch/x86/kernel/paravirt.c b/arch/x86/kernel/paravirt.c
index 930c88341e4ec875013d99f8c844ecace51088fb..afdb303285f874d481c7de2fbe7fd9f80873a424 100644
--- a/arch/x86/kernel/paravirt.c
+++ b/arch/x86/kernel/paravirt.c
@@ -41,6 +41,7 @@
 #include <asm/tlbflush.h>
 #include <asm/timer.h>
 #include <asm/special_insns.h>
+#include <asm/tlb.h>
 
 /*
  * nop stub, which must not clobber anything *including the stack* to
@@ -409,6 +410,7 @@ struct pv_mmu_ops pv_mmu_ops __ro_after_init = {
 	.flush_tlb_kernel = native_flush_tlb_global,
 	.flush_tlb_one_user = native_flush_tlb_one_user,
 	.flush_tlb_others = native_flush_tlb_others,
+	.tlb_remove_table = (void (*)(struct mmu_gather *, void *))tlb_remove_page,
 
 	.pgd_alloc = __paravirt_pgd_alloc,
 	.pgd_free = paravirt_nop,
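
[ Note: the function-pointer cast relies on the fact that every caller in
  arch/x86/mm/pgtable.c (below) passes the struct page * of the table
  being freed, which matches tlb_remove_page()'s second argument.  A
  cast-free equivalent would be a trivial wrapper; illustration only:

	static void native_tlb_remove_table(struct mmu_gather *tlb, void *table)
	{
		tlb_remove_page(tlb, (struct page *)table);
	}
]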
diff --git a/arch/x86/mm/pgtable.c b/arch/x86/mm/pgtable.c
index 3ef095c70ae31636fbecd185eac096ef5858d08b..e848a48117856c8e58dcf313339853563f5c6f7f 100644
--- a/arch/x86/mm/pgtable.c
+++ b/arch/x86/mm/pgtable.c
@@ -63,7 +63,7 @@ void ___pte_free_tlb(struct mmu_gather *tlb, struct page *pte)
 {
 	pgtable_page_dtor(pte);
 	paravirt_release_pte(page_to_pfn(pte));
-	tlb_remove_table(tlb, pte);
+	paravirt_tlb_remove_table(tlb, pte);
 }
 
 #if CONFIG_PGTABLE_LEVELS > 2
@@ -79,21 +79,21 @@ void ___pmd_free_tlb(struct mmu_gather *tlb, pmd_t *pmd)
 	tlb->need_flush_all = 1;
 #endif
 	pgtable_pmd_page_dtor(page);
-	tlb_remove_table(tlb, page);
+	paravirt_tlb_remove_table(tlb, page);
 }
 
 #if CONFIG_PGTABLE_LEVELS > 3
 void ___pud_free_tlb(struct mmu_gather *tlb, pud_t *pud)
 {
 	paravirt_release_pud(__pa(pud) >> PAGE_SHIFT);
-	tlb_remove_table(tlb, virt_to_page(pud));
+	paravirt_tlb_remove_table(tlb, virt_to_page(pud));
 }
 
 #if CONFIG_PGTABLE_LEVELS > 4
 void ___p4d_free_tlb(struct mmu_gather *tlb, p4d_t *p4d)
 {
 	paravirt_release_p4d(__pa(p4d) >> PAGE_SHIFT);
-	tlb_remove_table(tlb, virt_to_page(p4d));
+	paravirt_tlb_remove_table(tlb, virt_to_page(p4d));
 }
 #endif	/* CONFIG_PGTABLE_LEVELS > 4 */
 #endif	/* CONFIG_PGTABLE_LEVELS > 3 */
diff --git a/arch/x86/mm/tlb.c b/arch/x86/mm/tlb.c
index 752dbf4e0e5071b9a71908822231ceda5940dabf..9517d1b2a2810817907640c6c2dde698c62b3e79 100644
--- a/arch/x86/mm/tlb.c
+++ b/arch/x86/mm/tlb.c
@@ -7,7 +7,6 @@
 #include <linux/export.h>
 #include <linux/cpu.h>
 #include <linux/debugfs.h>
-#include <linux/gfp.h>
 
 #include <asm/tlbflush.h>
 #include <asm/mmu_context.h>
@@ -186,11 +185,8 @@ void switch_mm_irqs_off(struct mm_struct *prev, struct mm_struct *next,
 {
 	struct mm_struct *real_prev = this_cpu_read(cpu_tlbstate.loaded_mm);
 	u16 prev_asid = this_cpu_read(cpu_tlbstate.loaded_mm_asid);
-	bool was_lazy = this_cpu_read(cpu_tlbstate.is_lazy);
 	unsigned cpu = smp_processor_id();
 	u64 next_tlb_gen;
-	bool need_flush;
-	u16 new_asid;
 
 	/*
 	 * NB: The scheduler will call us with prev == next when switching
@@ -244,41 +240,20 @@ void switch_mm_irqs_off(struct mm_struct *prev, struct mm_struct *next,
 			   next->context.ctx_id);
 
 		/*
-		 * Even in lazy TLB mode, the CPU should stay set in the
-		 * mm_cpumask. The TLB shootdown code can figure out from
-		 * from cpu_tlbstate.is_lazy whether or not to send an IPI.
+		 * We don't currently support having a real mm loaded without
+		 * our cpu set in mm_cpumask().  We have all the bookkeeping
+		 * in place to figure out whether we would need to flush
+		 * if our cpu were cleared in mm_cpumask(), but we don't
+		 * currently use it.
 		 */
 		if (WARN_ON_ONCE(real_prev != &init_mm &&
 				 !cpumask_test_cpu(cpu, mm_cpumask(next))))
 			cpumask_set_cpu(cpu, mm_cpumask(next));
 
-		/*
-		 * If the CPU is not in lazy TLB mode, we are just switching
-		 * from one thread in a process to another thread in the same
-		 * process. No TLB flush required.
-		 */
-		if (!was_lazy)
-			return;
-
-		/*
-		 * Read the tlb_gen to check whether a flush is needed.
-		 * If the TLB is up to date, just use it.
-		 * The barrier synchronizes with the tlb_gen increment in
-		 * the TLB shootdown code.
-		 */
-		smp_mb();
-		next_tlb_gen = atomic64_read(&next->context.tlb_gen);
-		if (this_cpu_read(cpu_tlbstate.ctxs[prev_asid].tlb_gen) ==
-				next_tlb_gen)
-			return;
-
-		/*
-		 * TLB contents went out of date while we were in lazy
-		 * mode. Fall through to the TLB switching code below.
-		 */
-		new_asid = prev_asid;
-		need_flush = true;
+		return;
 	} else {
+		u16 new_asid;
+		bool need_flush;
 		u64 last_ctx_id = this_cpu_read(cpu_tlbstate.last_ctx_id);
 
 		/*
@@ -329,41 +304,41 @@ void switch_mm_irqs_off(struct mm_struct *prev, struct mm_struct *next,
 		next_tlb_gen = atomic64_read(&next->context.tlb_gen);
 
 		choose_new_asid(next, next_tlb_gen, &new_asid, &need_flush);
-	}
 
-	if (need_flush) {
-		this_cpu_write(cpu_tlbstate.ctxs[new_asid].ctx_id, next->context.ctx_id);
-		this_cpu_write(cpu_tlbstate.ctxs[new_asid].tlb_gen, next_tlb_gen);
-		load_new_mm_cr3(next->pgd, new_asid, true);
+		if (need_flush) {
+			this_cpu_write(cpu_tlbstate.ctxs[new_asid].ctx_id, next->context.ctx_id);
+			this_cpu_write(cpu_tlbstate.ctxs[new_asid].tlb_gen, next_tlb_gen);
+			load_new_mm_cr3(next->pgd, new_asid, true);
+
+			/*
+			 * NB: This gets called via leave_mm() in the idle path
+			 * where RCU functions differently.  Tracing normally
+			 * uses RCU, so we need to use the _rcuidle variant.
+			 *
+			 * (There is no good reason for this.  The idle code should
+			 *  be rearranged to call this before rcu_idle_enter().)
+			 */
+			trace_tlb_flush_rcuidle(TLB_FLUSH_ON_TASK_SWITCH, TLB_FLUSH_ALL);
+		} else {
+			/* The new ASID is already up to date. */
+			load_new_mm_cr3(next->pgd, new_asid, false);
+
+			/* See above wrt _rcuidle. */
+			trace_tlb_flush_rcuidle(TLB_FLUSH_ON_TASK_SWITCH, 0);
+		}
 
 		/*
-		 * NB: This gets called via leave_mm() in the idle path
-		 * where RCU functions differently.  Tracing normally
-		 * uses RCU, so we need to use the _rcuidle variant.
-		 *
-		 * (There is no good reason for this.  The idle code should
-		 *  be rearranged to call this before rcu_idle_enter().)
+		 * Record last user mm's context id, so we can avoid
+		 * flushing branch buffer with IBPB if we switch back
+		 * to the same user.
 		 */
-		trace_tlb_flush_rcuidle(TLB_FLUSH_ON_TASK_SWITCH, TLB_FLUSH_ALL);
-	} else {
-		/* The new ASID is already up to date. */
-		load_new_mm_cr3(next->pgd, new_asid, false);
+		if (next != &init_mm)
+			this_cpu_write(cpu_tlbstate.last_ctx_id, next->context.ctx_id);
 
-		/* See above wrt _rcuidle. */
-		trace_tlb_flush_rcuidle(TLB_FLUSH_ON_TASK_SWITCH, 0);
+		this_cpu_write(cpu_tlbstate.loaded_mm, next);
+		this_cpu_write(cpu_tlbstate.loaded_mm_asid, new_asid);
 	}
 
-	/*
-	 * Record last user mm's context id, so we can avoid
-	 * flushing branch buffer with IBPB if we switch back
-	 * to the same user.
-	 */
-	if (next != &init_mm)
-		this_cpu_write(cpu_tlbstate.last_ctx_id, next->context.ctx_id);
-
-	this_cpu_write(cpu_tlbstate.loaded_mm, next);
-	this_cpu_write(cpu_tlbstate.loaded_mm_asid, new_asid);
-
 	load_mm_cr4(next);
 	switch_ldt(real_prev, next);
 }
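
[ Note: with the lazy-IPI-avoidance machinery reverted, a CPU sitting in
  lazy TLB mode receives flush IPIs again, so the real_prev == next path
  can return without consulting tlb_gen; new_asid and need_flush become
  locals of the else branch.  The resulting control flow, sketched:

	/*
	 *	if (real_prev == next) {
	 *		// same mm: lazy CPUs were IPI'd all along, so
	 *		// the TLB is already coherent -- nothing to do
	 *		return;
	 *	} else {
	 *		u16 new_asid;
	 *		bool need_flush;
	 *		// choose ASID, flush if needed, load CR3,
	 *		// update cpu_tlbstate
	 *	}
	 */
]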
@@ -386,7 +361,20 @@ void enter_lazy_tlb(struct mm_struct *mm, struct task_struct *tsk)
 	if (this_cpu_read(cpu_tlbstate.loaded_mm) == &init_mm)
 		return;
 
-	this_cpu_write(cpu_tlbstate.is_lazy, true);
+	if (tlb_defer_switch_to_init_mm()) {
+		/*
+		 * There's a significant optimization that may be possible
+		 * here.  We have accurate enough TLB flush tracking that we
+		 * don't need to maintain coherence of TLB per se when we're
+		 * lazy.  We do, however, need to maintain coherence of
+		 * paging-structure caches.  We could, in principle, leave our
+		 * old mm loaded and only switch to init_mm when
+		 * tlb_remove_page() happens.
+		 */
+		this_cpu_write(cpu_tlbstate.is_lazy, true);
+	} else {
+		switch_mm(NULL, &init_mm, NULL);
+	}
 }
 
 /*
@@ -473,9 +461,6 @@ static void flush_tlb_func_common(const struct flush_tlb_info *f,
 		 * paging-structure cache to avoid speculatively reading
 		 * garbage into our TLB.  Since switching to init_mm is barely
 		 * slower than a minimal flush, just switch to init_mm.
-		 *
-		 * This should be rare, with native_flush_tlb_others skipping
-		 * IPIs to lazy TLB mode CPUs.
 		 */
 		switch_mm_irqs_off(NULL, &init_mm, NULL);
 		return;
@@ -582,9 +567,6 @@ static void flush_tlb_func_remote(void *info)
 void native_flush_tlb_others(const struct cpumask *cpumask,
 			     const struct flush_tlb_info *info)
 {
-	cpumask_var_t lazymask;
-	unsigned int cpu;
-
 	count_vm_tlb_event(NR_TLB_REMOTE_FLUSH);
 	if (info->end == TLB_FLUSH_ALL)
 		trace_tlb_flush(TLB_REMOTE_SEND_IPI, TLB_FLUSH_ALL);
@@ -608,6 +590,8 @@ void native_flush_tlb_others(const struct cpumask *cpumask,
 		 * that UV should be updated so that smp_call_function_many(),
 		 * etc, are optimal on UV.
 		 */
+		unsigned int cpu;
+
 		cpu = smp_processor_id();
 		cpumask = uv_flush_tlb_others(cpumask, info);
 		if (cpumask)
@@ -615,29 +599,8 @@ void native_flush_tlb_others(const struct cpumask *cpumask,
 					       (void *)info, 1);
 		return;
 	}
-
-	/*
-	 * A temporary cpumask is used in order to skip sending IPIs
-	 * to CPUs in lazy TLB state, while keeping them in mm_cpumask(mm).
-	 * If the allocation fails, simply IPI every CPU in mm_cpumask.
-	 */
-	if (!alloc_cpumask_var(&lazymask, GFP_ATOMIC)) {
-		smp_call_function_many(cpumask, flush_tlb_func_remote,
-			       (void *)info, 1);
-		return;
-	}
-
-	cpumask_copy(lazymask, cpumask);
-
-	for_each_cpu(cpu, lazymask) {
-		if (per_cpu(cpu_tlbstate.is_lazy, cpu))
-			cpumask_clear_cpu(cpu, lazymask);
-	}
-
-	smp_call_function_many(lazymask, flush_tlb_func_remote,
+	smp_call_function_many(cpumask, flush_tlb_func_remote,
 			       (void *)info, 1);
-
-	free_cpumask_var(lazymask);
 }
 
 /*
@@ -690,68 +653,6 @@ void flush_tlb_mm_range(struct mm_struct *mm, unsigned long start,
 	put_cpu();
 }
 
-void tlb_flush_remove_tables_local(void *arg)
-{
-	struct mm_struct *mm = arg;
-
-	if (this_cpu_read(cpu_tlbstate.loaded_mm) == mm &&
-			this_cpu_read(cpu_tlbstate.is_lazy)) {
-		/*
-		 * We're in lazy mode.  We need to at least flush our
-		 * paging-structure cache to avoid speculatively reading
-		 * garbage into our TLB.  Since switching to init_mm is barely
-		 * slower than a minimal flush, just switch to init_mm.
-		 */
-		switch_mm_irqs_off(NULL, &init_mm, NULL);
-	}
-}
-
-static void mm_fill_lazy_tlb_cpu_mask(struct mm_struct *mm,
-				      struct cpumask *lazy_cpus)
-{
-	int cpu;
-
-	for_each_cpu(cpu, mm_cpumask(mm)) {
-		if (!per_cpu(cpu_tlbstate.is_lazy, cpu))
-			cpumask_set_cpu(cpu, lazy_cpus);
-	}
-}
-
-void tlb_flush_remove_tables(struct mm_struct *mm)
-{
-	int cpu = get_cpu();
-	cpumask_var_t lazy_cpus;
-
-	if (cpumask_any_but(mm_cpumask(mm), cpu) >= nr_cpu_ids) {
-		put_cpu();
-		return;
-	}
-
-	if (!zalloc_cpumask_var(&lazy_cpus, GFP_ATOMIC)) {
-		/*
-		 * If the cpumask allocation fails, do a brute force flush
-		 * on all the CPUs that have this mm loaded.
-		 */
-		smp_call_function_many(mm_cpumask(mm),
-				tlb_flush_remove_tables_local, (void *)mm, 1);
-		put_cpu();
-		return;
-	}
-
-	/*
-	 * CPUs with !is_lazy either received a TLB flush IPI while the user
-	 * pages in this address range were unmapped, or have context switched
-	 * and reloaded %CR3 since then.
-	 *
-	 * Shootdown IPIs at page table freeing time only need to be sent to
-	 * CPUs that may have out of date TLB contents.
-	 */
-	mm_fill_lazy_tlb_cpu_mask(mm, lazy_cpus);
-	smp_call_function_many(lazy_cpus,
-				tlb_flush_remove_tables_local, (void *)mm, 1);
-	free_cpumask_var(lazy_cpus);
-	put_cpu();
-}
 
 static void do_flush_tlb_all(void *info)
 {
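
[ Note: the hunks above revert the targeted-IPI bookkeeping:
  native_flush_tlb_others() once more IPIs every CPU in the mask, lazy or
  not, and tlb_flush_remove_tables()/_local() disappear because page-table
  freeing is handled by the generic RCU path instead.  The restored
  remote-flush flow, as a call-sequence sketch:

	/*
	 *   flush_tlb_mm_range()
	 *     -> native_flush_tlb_others(mm_cpumask(mm), info)
	 *          -> IPI every CPU in the mask
	 *               -> flush_tlb_func_remote()
	 *                    -> flush_tlb_func_common()
	 *                         lazy CPU:   switch_mm_irqs_off(NULL,
	 *                                                &init_mm, NULL)
	 *                         active CPU: flush the requested range
	 */
]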
diff --git a/arch/x86/xen/mmu_pv.c b/arch/x86/xen/mmu_pv.c
index 9e7012858420c99174002ae47a2d79e81aaad8df..45b700ac5fe7e0685578597ee4189b8124b403c1 100644
--- a/arch/x86/xen/mmu_pv.c
+++ b/arch/x86/xen/mmu_pv.c
@@ -67,6 +67,7 @@
 #include <asm/init.h>
 #include <asm/pat.h>
 #include <asm/smp.h>
+#include <asm/tlb.h>
 
 #include <asm/xen/hypercall.h>
 #include <asm/xen/hypervisor.h>
@@ -2399,6 +2400,7 @@ static const struct pv_mmu_ops xen_mmu_ops __initconst = {
 	.flush_tlb_kernel = xen_flush_tlb,
 	.flush_tlb_one_user = xen_flush_tlb_one_user,
 	.flush_tlb_others = xen_flush_tlb_others,
+	.tlb_remove_table = tlb_remove_table,
 
 	.pgd_alloc = xen_pgd_alloc,
 	.pgd_free = xen_pgd_free,
diff --git a/include/asm-generic/tlb.h b/include/asm-generic/tlb.h
index e811ef7b835006e232bad707f02fd5dfdc7e0691..b3353e21f3b3ec95220e1706bae89d3969d9e918 100644
--- a/include/asm-generic/tlb.h
+++ b/include/asm-generic/tlb.h
@@ -15,6 +15,7 @@
 #ifndef _ASM_GENERIC__TLB_H
 #define _ASM_GENERIC__TLB_H
 
+#include <linux/mmu_notifier.h>
 #include <linux/swap.h>
 #include <asm/pgalloc.h>
 #include <asm/tlbflush.h>
@@ -138,6 +139,16 @@ static inline void __tlb_reset_range(struct mmu_gather *tlb)
 	}
 }
 
+static inline void tlb_flush_mmu_tlbonly(struct mmu_gather *tlb)
+{
+	if (!tlb->end)
+		return;
+
+	tlb_flush(tlb);
+	mmu_notifier_invalidate_range(tlb->mm, tlb->start, tlb->end);
+	__tlb_reset_range(tlb);
+}
+
 static inline void tlb_remove_page_size(struct mmu_gather *tlb,
 					struct page *page, int page_size)
 {
@@ -186,10 +197,8 @@ static inline void tlb_remove_check_page_size_change(struct mmu_gather *tlb,
 
 #define __tlb_end_vma(tlb, vma)					\
 	do {							\
-		if (!tlb->fullmm && tlb->end) {			\
-			tlb_flush(tlb);				\
-			__tlb_reset_range(tlb);			\
-		}						\
+		if (!tlb->fullmm)				\
+			tlb_flush_mmu_tlbonly(tlb);		\
 	} while (0)
 
 #ifndef tlb_end_vma
@@ -303,14 +312,4 @@ static inline void tlb_remove_check_page_size_change(struct mmu_gather *tlb,
 
 #define tlb_migrate_finish(mm) do {} while (0)
 
-/*
- * Used to flush the TLB when page tables are removed, when lazy
- * TLB mode may cause a CPU to retain intermediate translations
- * pointing to about-to-be-freed page table memory.
- */
-#ifndef HAVE_TLB_FLUSH_REMOVE_TABLES
-#define tlb_flush_remove_tables(mm) do {} while (0)
-#define tlb_flush_remove_tables_local(mm) do {} while (0)
-#endif
-
 #endif /* _ASM_GENERIC__TLB_H */
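
[ Note: tlb_flush_mmu_tlbonly() moves here from mm/memory.c so that both
  __tlb_end_vma() above and the new tlb_table_invalidate() in mm/memory.c
  can share it; the linux/mmu_notifier.h include is needed because the
  helper calls mmu_notifier_invalidate_range().  Usage sketch, name
  illustrative:

	static inline void example_sync_tlb(struct mmu_gather *tlb)
	{
		/* flushes tlb->start..tlb->end; no-op if nothing gathered */
		tlb_flush_mmu_tlbonly(tlb);
	}
]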
diff --git a/mm/memory.c b/mm/memory.c
index 19f47d7b9b86795ed8b116cffceedd1bb625675d..3ff4394a2e1b01ae9a4ace4b0cede070f8545645 100644
--- a/mm/memory.c
+++ b/mm/memory.c
@@ -238,23 +238,13 @@ void arch_tlb_gather_mmu(struct mmu_gather *tlb, struct mm_struct *mm,
 	__tlb_reset_range(tlb);
 }
 
-static void tlb_flush_mmu_tlbonly(struct mmu_gather *tlb)
+static void tlb_flush_mmu_free(struct mmu_gather *tlb)
 {
-	if (!tlb->end)
-		return;
+	struct mmu_gather_batch *batch;
 
-	tlb_flush(tlb);
-	mmu_notifier_invalidate_range(tlb->mm, tlb->start, tlb->end);
 #ifdef CONFIG_HAVE_RCU_TABLE_FREE
 	tlb_table_flush(tlb);
 #endif
-	__tlb_reset_range(tlb);
-}
-
-static void tlb_flush_mmu_free(struct mmu_gather *tlb)
-{
-	struct mmu_gather_batch *batch;
-
 	for (batch = &tlb->local; batch && batch->nr; batch = batch->next) {
 		free_pages_and_swap_cache(batch->pages, batch->nr);
 		batch->nr = 0;
@@ -326,20 +316,31 @@ bool __tlb_remove_page_size(struct mmu_gather *tlb, struct page *page, int page_
 
 #ifdef CONFIG_HAVE_RCU_TABLE_FREE
 
-static void tlb_remove_table_smp_sync(void *arg)
+/*
+ * See the comment near struct mmu_table_batch.
+ */
+
+/*
+ * Flushes the TLB if we want tlb_remove_table() to imply TLB invalidates.
+ */
+static inline void tlb_table_invalidate(struct mmu_gather *tlb)
 {
-	struct mm_struct __maybe_unused *mm = arg;
+#ifdef CONFIG_HAVE_RCU_TABLE_INVALIDATE
 	/*
-	 * On most architectures this does nothing. Simply delivering the
-	 * interrupt is enough to prevent races with software page table
-	 * walking like that done in get_user_pages_fast.
-	 *
-	 * See the comment near struct mmu_table_batch.
+	 * Invalidate page-table caches used by hardware walkers. Then we still
+	 * need to RCU-sched wait while freeing the pages because software
+	 * walkers can still be in-flight.
 	 */
-	tlb_flush_remove_tables_local(mm);
+	tlb_flush_mmu_tlbonly(tlb);
+#endif
+}
+
+static void tlb_remove_table_smp_sync(void *arg)
+{
+	/* Simply deliver the interrupt */
 }
 
-static void tlb_remove_table_one(void *table, struct mmu_gather *tlb)
+static void tlb_remove_table_one(void *table)
 {
 	/*
 	 * This isn't an RCU grace period and hence the page-tables cannot be
@@ -348,7 +349,7 @@ static void tlb_remove_table_one(void *table, struct mmu_gather *tlb)
 	 * It is however sufficient for software page-table walkers that rely on
 	 * IRQ disabling. See the comment near struct mmu_table_batch.
 	 */
-	smp_call_function(tlb_remove_table_smp_sync, tlb->mm, 1);
+	smp_call_function(tlb_remove_table_smp_sync, NULL, 1);
 	__tlb_remove_table(table);
 }
 
@@ -369,9 +370,8 @@ void tlb_table_flush(struct mmu_gather *tlb)
 {
 	struct mmu_table_batch **batch = &tlb->batch;
 
-	tlb_flush_remove_tables(tlb->mm);
-
 	if (*batch) {
+		tlb_table_invalidate(tlb);
 		call_rcu_sched(&(*batch)->rcu, tlb_remove_table_rcu);
 		*batch = NULL;
 	}
@@ -381,23 +381,16 @@ void tlb_remove_table(struct mmu_gather *tlb, void *table)
 {
 	struct mmu_table_batch **batch = &tlb->batch;
 
-	/*
-	 * When there's less then two users of this mm there cannot be a
-	 * concurrent page-table walk.
-	 */
-	if (atomic_read(&tlb->mm->mm_users) < 2) {
-		__tlb_remove_table(table);
-		return;
-	}
-
 	if (*batch == NULL) {
 		*batch = (struct mmu_table_batch *)__get_free_page(GFP_NOWAIT | __GFP_NOWARN);
 		if (*batch == NULL) {
-			tlb_remove_table_one(table, tlb);
+			tlb_table_invalidate(tlb);
+			tlb_remove_table_one(table);
 			return;
 		}
 		(*batch)->nr = 0;
 	}
+
 	(*batch)->tables[(*batch)->nr++] = table;
 	if ((*batch)->nr == MAX_TABLE_BATCH)
 		tlb_table_flush(tlb);
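
[ Note: putting the mm/memory.c pieces together, the freeing path under
  CONFIG_HAVE_RCU_TABLE_INVALIDATE becomes (call-sequence sketch, not new
  code):

	/*
	 *   tlb_remove_table(tlb, table)
	 *     -> append to tlb->batch, or tlb_table_invalidate() +
	 *        tlb_remove_table_one() if the batch page can't be
	 *        allocated
	 *   tlb_table_flush(tlb)
	 *     -> tlb_table_invalidate()	// flush hardware-walker caches
	 *     -> call_rcu_sched(&(*batch)->rcu, tlb_remove_table_rcu)
	 *          -> __tlb_remove_table()	// free once software walkers
	 *					// are guaranteed gone
	 */

  The mm_users < 2 fast path is dropped deliberately: a low user count
  does not rule out concurrent walkers, since lazy-TLB CPUs hold only
  mm_count references and their hardware walkers may still traverse these
  tables speculatively, so the invalidate plus grace period is needed
  unconditionally. ]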