diff --git a/drivers/vfio/vfio_iommu_type1.c b/drivers/vfio/vfio_iommu_type1.c
index 414bb2e8c7bd68147d4ef48654b75c049ad7231d..306778dcbb9e579495d10cff0d78e18aaaf6e39b 100644
--- a/drivers/vfio/vfio_iommu_type1.c
+++ b/drivers/vfio/vfio_iommu_type1.c
@@ -347,7 +347,7 @@ static int put_pfn(unsigned long pfn, int prot)
 }
 
 static int vaddr_get_pfn(struct mm_struct *mm, unsigned long vaddr,
-			 int prot, unsigned long *pfn, bool handle_mmap_sem)
+			 int prot, unsigned long *pfn)
 {
 	struct page *page[1];
 	struct vm_area_struct *vma;
@@ -358,8 +358,7 @@ static int vaddr_get_pfn(struct mm_struct *mm, unsigned long vaddr,
 	if (prot & IOMMU_WRITE)
 		flags |= FOLL_WRITE;
 
-	if (handle_mmap_sem)
-		down_read(&mm->mmap_sem);
+	down_read(&mm->mmap_sem);
 	if (mm == current->mm) {
 		ret = get_user_pages_longterm(vaddr, 1, flags, page, vmas);
 	} else {
@@ -377,16 +376,14 @@ static int vaddr_get_pfn(struct mm_struct *mm, unsigned long vaddr,
 			put_page(page[0]);
 		}
 	}
-	if (handle_mmap_sem)
-		up_read(&mm->mmap_sem);
+	up_read(&mm->mmap_sem);
 
 	if (ret == 1) {
 		*pfn = page_to_pfn(page[0]);
 		return 0;
 	}
 
-	if (handle_mmap_sem)
-		down_read(&mm->mmap_sem);
+	down_read(&mm->mmap_sem);
 
 	vma = find_vma_intersection(mm, vaddr, vaddr + 1);
 
@@ -396,8 +393,7 @@ static int vaddr_get_pfn(struct mm_struct *mm, unsigned long vaddr,
 			ret = 0;
 	}
 
-	if (handle_mmap_sem)
-		up_read(&mm->mmap_sem);
+	up_read(&mm->mmap_sem);
 
 	return ret;
 }
@@ -419,12 +415,9 @@ static long vfio_pin_pages_remote(struct vfio_dma *dma, unsigned long vaddr,
 	if (!mm)
 		return -ENODEV;
 
-	down_read(&mm->mmap_sem);
-	ret = vaddr_get_pfn(mm, vaddr, dma->prot, pfn_base, false);
-	if (ret) {
-		up_read(&mm->mmap_sem);
+	ret = vaddr_get_pfn(mm, vaddr, dma->prot, pfn_base);
+	if (ret)
 		return ret;
-	}
 
 	pinned++;
 	rsvd = is_invalid_reserved_pfn(*pfn_base);
@@ -439,7 +432,6 @@ static long vfio_pin_pages_remote(struct vfio_dma *dma, unsigned long vaddr,
 			put_pfn(*pfn_base, dma->prot);
 			pr_warn("%s: RLIMIT_MEMLOCK (%ld) exceeded\n", __func__,
 				limit << PAGE_SHIFT);
-			up_read(&mm->mmap_sem);
 			return -ENOMEM;
 		}
 		lock_acct++;
@@ -451,7 +443,7 @@ static long vfio_pin_pages_remote(struct vfio_dma *dma, unsigned long vaddr,
 	/* Lock all the consecutive pages from pfn_base */
 	for (vaddr += PAGE_SIZE, iova += PAGE_SIZE; pinned < npage;
 	     pinned++, vaddr += PAGE_SIZE, iova += PAGE_SIZE) {
-		ret = vaddr_get_pfn(mm, vaddr, dma->prot, &pfn, false);
+		ret = vaddr_get_pfn(mm, vaddr, dma->prot, &pfn);
 		if (ret)
 			break;
 
@@ -468,7 +460,6 @@ static long vfio_pin_pages_remote(struct vfio_dma *dma, unsigned long vaddr,
 			pr_warn("%s: RLIMIT_MEMLOCK (%ld) exceeded\n", __func__,
 				limit << PAGE_SHIFT);
 			ret = -ENOMEM;
-			up_read(&mm->mmap_sem);
 			goto unpin_out;
 		}
 		lock_acct++;
@@ -476,7 +467,6 @@ static long vfio_pin_pages_remote(struct vfio_dma *dma, unsigned long vaddr,
 	}
 
 out:
-	up_read(&mm->mmap_sem);
 	ret = vfio_lock_acct(dma, lock_acct, false);
 
 unpin_out:
@@ -523,7 +513,7 @@ static int vfio_pin_page_external(struct vfio_dma *dma, unsigned long vaddr,
 	if (!mm)
 		return -ENODEV;
 
-	ret = vaddr_get_pfn(mm, vaddr, dma->prot, pfn_base, true);
+	ret = vaddr_get_pfn(mm, vaddr, dma->prot, pfn_base);
 	if (!ret && do_accounting && !is_invalid_reserved_pfn(*pfn_base)) {
 		ret = vfio_lock_acct(dma, 1, true);
 		if (ret) {