diff --git a/drivers/vfio/vfio_iommu_type1.c b/drivers/vfio/vfio_iommu_type1.c
index 18954846dff6a51501f142f2a8a1da850a3f8ca2..943c41eef74a8911f0f567b8471c7cfa6a2d8268 100644
--- a/drivers/vfio/vfio_iommu_type1.c
+++ b/drivers/vfio/vfio_iommu_type1.c
@@ -347,7 +347,7 @@ static int put_pfn(unsigned long pfn, int prot)
 }
 
 static int vaddr_get_pfn(struct mm_struct *mm, unsigned long vaddr,
-			 int prot, unsigned long *pfn)
+			 int prot, unsigned long *pfn, bool handle_mmap_sem)
 {
 	struct page *page[1];
 	struct vm_area_struct *vma;
@@ -358,7 +358,8 @@ static int vaddr_get_pfn(struct mm_struct *mm, unsigned long vaddr,
 	if (prot & IOMMU_WRITE)
 		flags |= FOLL_WRITE;
 
-	down_read(&mm->mmap_sem);
+	if (handle_mmap_sem)
+		down_read(&mm->mmap_sem);
 	if (mm == current->mm) {
 		ret = get_user_pages_longterm(vaddr, 1, flags, page, vmas);
 	} else {
@@ -376,14 +377,16 @@ static int vaddr_get_pfn(struct mm_struct *mm, unsigned long vaddr,
 			put_page(page[0]);
 		}
 	}
-	up_read(&mm->mmap_sem);
+	if (handle_mmap_sem)
+		up_read(&mm->mmap_sem);
 
 	if (ret == 1) {
 		*pfn = page_to_pfn(page[0]);
 		return 0;
 	}
 
-	down_read(&mm->mmap_sem);
+	if (handle_mmap_sem)
+		down_read(&mm->mmap_sem);
 
 	vma = find_vma_intersection(mm, vaddr, vaddr + 1);
 
@@ -393,7 +396,8 @@ static int vaddr_get_pfn(struct mm_struct *mm, unsigned long vaddr,
 			ret = 0;
 	}
 
-	up_read(&mm->mmap_sem);
+	if (handle_mmap_sem)
+		up_read(&mm->mmap_sem);
 	return ret;
 }
 
@@ -415,9 +419,12 @@ static long vfio_pin_pages_remote(struct vfio_dma *dma, unsigned long vaddr,
 	if (!mm)
 		return -ENODEV;
 
-	ret = vaddr_get_pfn(mm, vaddr, dma->prot, pfn_base);
-	if (ret)
+	down_read(&mm->mmap_sem);
+	ret = vaddr_get_pfn(mm, vaddr, dma->prot, pfn_base, false);
+	if (ret) {
+		up_read(&mm->mmap_sem);
 		return ret;
+	}
 
 	pinned++;
 	rsvd = is_invalid_reserved_pfn(*pfn_base);
@@ -432,6 +439,7 @@ static long vfio_pin_pages_remote(struct vfio_dma *dma, unsigned long vaddr,
 			put_pfn(*pfn_base, dma->prot);
 			pr_warn("%s: RLIMIT_MEMLOCK (%ld) exceeded\n", __func__,
 					limit << PAGE_SHIFT);
+			up_read(&mm->mmap_sem);
 			return -ENOMEM;
 		}
 		lock_acct++;
@@ -443,7 +451,7 @@ static long vfio_pin_pages_remote(struct vfio_dma *dma, unsigned long vaddr,
 	/* Lock all the consecutive pages from pfn_base */
 	for (vaddr += PAGE_SIZE, iova += PAGE_SIZE; pinned < npage;
 	     pinned++, vaddr += PAGE_SIZE, iova += PAGE_SIZE) {
-		ret = vaddr_get_pfn(mm, vaddr, dma->prot, &pfn);
+		ret = vaddr_get_pfn(mm, vaddr, dma->prot, &pfn, false);
 		if (ret)
 			break;
 
@@ -460,6 +468,7 @@ static long vfio_pin_pages_remote(struct vfio_dma *dma, unsigned long vaddr,
 				pr_warn("%s: RLIMIT_MEMLOCK (%ld) exceeded\n",
 					__func__, limit << PAGE_SHIFT);
 				ret = -ENOMEM;
+				up_read(&mm->mmap_sem);
 				goto unpin_out;
 			}
 			lock_acct++;
@@ -467,6 +476,7 @@ static long vfio_pin_pages_remote(struct vfio_dma *dma, unsigned long vaddr,
 	}
 
 out:
+	up_read(&mm->mmap_sem);
 	ret = vfio_lock_acct(dma, lock_acct, false);
 
 unpin_out:
@@ -513,7 +523,7 @@ static int vfio_pin_page_external(struct vfio_dma *dma, unsigned long vaddr,
 	if (!mm)
 		return -ENODEV;
 
-	ret = vaddr_get_pfn(mm, vaddr, dma->prot, pfn_base);
+	ret = vaddr_get_pfn(mm, vaddr, dma->prot, pfn_base, true);
 	if (!ret && do_accounting && !is_invalid_reserved_pfn(*pfn_base)) {
 		ret = vfio_lock_acct(dma, 1, true);
 		if (ret) {