diff --git a/drivers/iommu/amd_iommu_init.c b/drivers/iommu/amd_iommu_init.c
index 182b80ba79c56868ecdb83fb2120a61e39abb408..8b026bfe05d1b545d33433d7d15545039feb4b35 100644
--- a/drivers/iommu/amd_iommu_init.c
+++ b/drivers/iommu/amd_iommu_init.c
@@ -196,6 +196,8 @@ static u32 rlookup_table_size;	/* size if the rlookup table */
  */
 extern void iommu_flush_all_caches(struct amd_iommu *iommu);
 
+static int __init amd_iommu_enable_interrupts(void);
+
 static inline void update_last_devid(u16 devid)
 {
 	if (devid > amd_iommu_last_bdf)
@@ -1383,7 +1385,6 @@ static void enable_iommus(void)
 		iommu_enable_ppr_log(iommu);
 		iommu_enable_gt(iommu);
 		iommu_set_exclusion_range(iommu);
-		iommu_init_msi(iommu);
 		iommu_enable(iommu);
 		iommu_flush_all_caches(iommu);
 	}
@@ -1411,6 +1412,8 @@ static void amd_iommu_resume(void)
 
 	/* re-load the hardware */
 	enable_iommus();
+
+	amd_iommu_enable_interrupts();
 }
 
 static int amd_iommu_suspend(void)
@@ -1595,6 +1598,21 @@ int __init amd_iommu_init_hardware(void)
 	return ret;
 }
 
+static int __init amd_iommu_enable_interrupts(void)
+{
+	struct amd_iommu *iommu;
+	int ret = 0;
+
+	for_each_iommu(iommu) {
+		ret = iommu_init_msi(iommu);
+		if (ret)
+			goto out;
+	}
+
+out:
+	return ret;
+}
+
 /*
  * This is the core init function for AMD IOMMU hardware in the system.
  * This function is called from the generic x86 DMA layer initialization
@@ -1612,6 +1630,10 @@ static int __init amd_iommu_init(void)
 	if (ret)
 		goto out;
 
+	ret = amd_iommu_enable_interrupts();
+	if (ret)
+		goto free;
+
 	if (iommu_pass_through)
 		ret = amd_iommu_init_passthrough();
 	else