diff --git a/arch/x86/kernel/amd_iommu.c b/arch/x86/kernel/amd_iommu.c index 01c68c38840d8b8997e5d801e508dd20185605f0..695e0fc41b108389242ceb41fbf872af004f75df 100644 --- a/arch/x86/kernel/amd_iommu.c +++ b/arch/x86/kernel/amd_iommu.c @@ -645,6 +645,18 @@ static void set_device_domain(struct amd_iommu *iommu, * *****************************************************************************/ +/* + * This function checks if the driver got a valid device from the caller to + * avoid dereferencing invalid pointers. + */ +static bool check_device(struct device *dev) +{ + if (!dev || !dev->dma_mask) + return false; + + return true; +} + /* * In the dma_ops path we only have the struct device. This function * finds the corresponding IOMMU, the protection domain and the @@ -661,18 +673,19 @@ static int get_device_resources(struct device *dev, struct pci_dev *pcidev; u16 _bdf; - BUG_ON(!dev || dev->bus != &pci_bus_type || !dev->dma_mask); + *iommu = NULL; + *domain = NULL; + *bdf = 0xffff; + + if (dev->bus != &pci_bus_type) + return 0; pcidev = to_pci_dev(dev); _bdf = calc_devid(pcidev->bus->number, pcidev->devfn); /* device not translated by any IOMMU in the system? */ - if (_bdf > amd_iommu_last_bdf) { - *iommu = NULL; - *domain = NULL; - *bdf = 0xffff; + if (_bdf > amd_iommu_last_bdf) return 0; - } *bdf = amd_iommu_alias_table[_bdf]; @@ -826,6 +839,9 @@ static dma_addr_t map_single(struct device *dev, phys_addr_t paddr, u16 devid; dma_addr_t addr; + if (!check_device(dev)) + return bad_dma_address; + get_device_resources(dev, &iommu, &domain, &devid); if (iommu == NULL || domain == NULL) @@ -860,7 +876,8 @@ static void unmap_single(struct device *dev, dma_addr_t dma_addr, struct protection_domain *domain; u16 devid; - if (!get_device_resources(dev, &iommu, &domain, &devid)) + if (!check_device(dev) || + !get_device_resources(dev, &iommu, &domain, &devid)) /* device not handled by any AMD IOMMU */ return; @@ -910,6 +927,9 @@ static int map_sg(struct device *dev, struct scatterlist *sglist, phys_addr_t paddr; int mapped_elems = 0; + if (!check_device(dev)) + return 0; + get_device_resources(dev, &iommu, &domain, &devid); if (!iommu || !domain) @@ -967,7 +987,8 @@ static void unmap_sg(struct device *dev, struct scatterlist *sglist, u16 devid; int i; - if (!get_device_resources(dev, &iommu, &domain, &devid)) + if (!check_device(dev) || + !get_device_resources(dev, &iommu, &domain, &devid)) return; spin_lock_irqsave(&domain->lock, flags); @@ -999,6 +1020,9 @@ static void *alloc_coherent(struct device *dev, size_t size, u16 devid; phys_addr_t paddr; + if (!check_device(dev)) + return NULL; + virt_addr = (void *)__get_free_pages(flag, get_order(size)); if (!virt_addr) return 0; @@ -1047,6 +1071,9 @@ static void free_coherent(struct device *dev, size_t size, struct protection_domain *domain; u16 devid; + if (!check_device(dev)) + return; + get_device_resources(dev, &iommu, &domain, &devid); if (!iommu || !domain)