diff --git a/drivers/vfio/vfio_iommu_type1.c b/drivers/vfio/vfio_iommu_type1.c index 18954846dff6a51501f142f2a8a1da850a3f8ca2..943c41eef74a8911f0f567b8471c7cfa6a2d8268 100644 --- a/drivers/vfio/vfio_iommu_type1.c +++ b/drivers/vfio/vfio_iommu_type1.c @@ -347,7 +347,7 @@ static int put_pfn(unsigned long pfn, int prot) } static int vaddr_get_pfn(struct mm_struct *mm, unsigned long vaddr, - int prot, unsigned long *pfn) + int prot, unsigned long *pfn, bool handle_mmap_sem) { struct page *page[1]; struct vm_area_struct *vma; @@ -358,7 +358,8 @@ static int vaddr_get_pfn(struct mm_struct *mm, unsigned long vaddr, if (prot & IOMMU_WRITE) flags |= FOLL_WRITE; - down_read(&mm->mmap_sem); + if (handle_mmap_sem) + down_read(&mm->mmap_sem); if (mm == current->mm) { ret = get_user_pages_longterm(vaddr, 1, flags, page, vmas); } else { @@ -376,14 +377,16 @@ static int vaddr_get_pfn(struct mm_struct *mm, unsigned long vaddr, put_page(page[0]); } } - up_read(&mm->mmap_sem); + if (handle_mmap_sem) + up_read(&mm->mmap_sem); if (ret == 1) { *pfn = page_to_pfn(page[0]); return 0; } - down_read(&mm->mmap_sem); + if (handle_mmap_sem) + down_read(&mm->mmap_sem); vma = find_vma_intersection(mm, vaddr, vaddr + 1); @@ -393,7 +396,8 @@ static int vaddr_get_pfn(struct mm_struct *mm, unsigned long vaddr, ret = 0; } - up_read(&mm->mmap_sem); + if (handle_mmap_sem) + up_read(&mm->mmap_sem); return ret; } @@ -415,9 +419,12 @@ static long vfio_pin_pages_remote(struct vfio_dma *dma, unsigned long vaddr, if (!mm) return -ENODEV; - ret = vaddr_get_pfn(mm, vaddr, dma->prot, pfn_base); - if (ret) + down_read(&mm->mmap_sem); + ret = vaddr_get_pfn(mm, vaddr, dma->prot, pfn_base, false); + if (ret) { + up_read(&mm->mmap_sem); return ret; + } pinned++; rsvd = is_invalid_reserved_pfn(*pfn_base); @@ -432,6 +439,7 @@ static long vfio_pin_pages_remote(struct vfio_dma *dma, unsigned long vaddr, put_pfn(*pfn_base, dma->prot); pr_warn("%s: RLIMIT_MEMLOCK (%ld) exceeded\n", __func__, limit << PAGE_SHIFT); + up_read(&mm->mmap_sem); return -ENOMEM; } lock_acct++; @@ -443,7 +451,7 @@ static long vfio_pin_pages_remote(struct vfio_dma *dma, unsigned long vaddr, /* Lock all the consecutive pages from pfn_base */ for (vaddr += PAGE_SIZE, iova += PAGE_SIZE; pinned < npage; pinned++, vaddr += PAGE_SIZE, iova += PAGE_SIZE) { - ret = vaddr_get_pfn(mm, vaddr, dma->prot, &pfn); + ret = vaddr_get_pfn(mm, vaddr, dma->prot, &pfn, false); if (ret) break; @@ -460,6 +468,7 @@ static long vfio_pin_pages_remote(struct vfio_dma *dma, unsigned long vaddr, pr_warn("%s: RLIMIT_MEMLOCK (%ld) exceeded\n", __func__, limit << PAGE_SHIFT); ret = -ENOMEM; + up_read(&mm->mmap_sem); goto unpin_out; } lock_acct++; @@ -467,6 +476,7 @@ static long vfio_pin_pages_remote(struct vfio_dma *dma, unsigned long vaddr, } out: + up_read(&mm->mmap_sem); ret = vfio_lock_acct(dma, lock_acct, false); unpin_out: @@ -513,7 +523,7 @@ static int vfio_pin_page_external(struct vfio_dma *dma, unsigned long vaddr, if (!mm) return -ENODEV; - ret = vaddr_get_pfn(mm, vaddr, dma->prot, pfn_base); + ret = vaddr_get_pfn(mm, vaddr, dma->prot, pfn_base, true); if (!ret && do_accounting && !is_invalid_reserved_pfn(*pfn_base)) { ret = vfio_lock_acct(dma, 1, true); if (ret) {