diff --git a/arch/x86/kernel/amd_iommu.c b/arch/x86/kernel/amd_iommu.c index f4747fe70aaa75f69f1a18878ce6be14cb2124d4..aab9125ac0b2c23af0adfd4ffca959eb87460957 100644 --- a/arch/x86/kernel/amd_iommu.c +++ b/arch/x86/kernel/amd_iommu.c @@ -798,3 +798,77 @@ static void unmap_sg(struct device *dev, struct scatterlist *sglist, spin_unlock_irqrestore(&domain->lock, flags); } +static void *alloc_coherent(struct device *dev, size_t size, + dma_addr_t *dma_addr, gfp_t flag) +{ + unsigned long flags; + void *virt_addr; + struct amd_iommu *iommu; + struct protection_domain *domain; + u16 devid; + phys_addr_t paddr; + + virt_addr = (void *)__get_free_pages(flag, get_order(size)); + if (!virt_addr) + return 0; + + memset(virt_addr, 0, size); + paddr = virt_to_phys(virt_addr); + + get_device_resources(dev, &iommu, &domain, &devid); + + if (!iommu || !domain) { + *dma_addr = (dma_addr_t)paddr; + return virt_addr; + } + + spin_lock_irqsave(&domain->lock, flags); + + *dma_addr = __map_single(dev, iommu, domain->priv, paddr, + size, DMA_BIDIRECTIONAL); + + if (*dma_addr == bad_dma_address) { + free_pages((unsigned long)virt_addr, get_order(size)); + virt_addr = NULL; + goto out; + } + + if (iommu_has_npcache(iommu)) + iommu_flush_pages(iommu, domain->id, *dma_addr, size); + + if (iommu->need_sync) + iommu_completion_wait(iommu); + +out: + spin_unlock_irqrestore(&domain->lock, flags); + + return virt_addr; +} + +static void free_coherent(struct device *dev, size_t size, + void *virt_addr, dma_addr_t dma_addr) +{ + unsigned long flags; + struct amd_iommu *iommu; + struct protection_domain *domain; + u16 devid; + + get_device_resources(dev, &iommu, &domain, &devid); + + if (!iommu || !domain) + goto free_mem; + + spin_lock_irqsave(&domain->lock, flags); + + __unmap_single(iommu, domain->priv, dma_addr, size, DMA_BIDIRECTIONAL); + iommu_flush_pages(iommu, domain->id, dma_addr, size); + + if (iommu->need_sync) + iommu_completion_wait(iommu); + + spin_unlock_irqrestore(&domain->lock, flags); + +free_mem: + free_pages((unsigned long)virt_addr, get_order(size)); +} +