diff --git a/arch/x86/events/amd/iommu.c b/arch/x86/events/amd/iommu.c
index 1c1a7e45dc64f1e941c75d4bdcd7e7e23e578a19..913745f1419baf994f38e64ab75cfe9d1545f119 100644
--- a/arch/x86/events/amd/iommu.c
+++ b/arch/x86/events/amd/iommu.c
@@ -19,8 +19,6 @@
 #include "../perf_event.h"
 #include "iommu.h"
 
-#define COUNTER_SHIFT		16
-
 /* iommu pmu conf masks */
 #define GET_CSOURCE(x)     ((x)->conf & 0xFFULL)
 #define GET_DEVID(x)       (((x)->conf >> 8)  & 0xFFFFULL)
@@ -286,22 +284,31 @@ static void perf_iommu_start(struct perf_event *event, int flags)
 	WARN_ON_ONCE(!(hwc->state & PERF_HES_UPTODATE));
 	hwc->state = 0;
 
+	/*
+	 * To account for power-gating, which prevents write to
+	 * the counter, we need to enable the counter
+	 * before setting up counter register.
+	 */
+	perf_iommu_enable_event(event);
+
 	if (flags & PERF_EF_RELOAD) {
-		u64 prev_raw_count = local64_read(&hwc->prev_count);
+		u64 count = 0;
 		struct amd_iommu *iommu = perf_event_2_iommu(event);
 
+		/*
+		 * Since the IOMMU PMU only support counting mode,
+		 * the counter always start with value zero.
+		 */
 		amd_iommu_pc_set_reg(iommu, hwc->iommu_bank, hwc->iommu_cntr,
-				     IOMMU_PC_COUNTER_REG, &prev_raw_count);
+				     IOMMU_PC_COUNTER_REG, &count);
 	}
 
-	perf_iommu_enable_event(event);
 	perf_event_update_userpage(event);
-
 }
 
 static void perf_iommu_read(struct perf_event *event)
 {
-	u64 count, prev, delta;
+	u64 count;
 	struct hw_perf_event *hwc = &event->hw;
 	struct amd_iommu *iommu = perf_event_2_iommu(event);
 
@@ -312,14 +319,11 @@ static void perf_iommu_read(struct perf_event *event)
 	/* IOMMU pc counter register is only 48 bits */
 	count &= GENMASK_ULL(47, 0);
 
-	prev = local64_read(&hwc->prev_count);
-	if (local64_cmpxchg(&hwc->prev_count, prev, count) != prev)
-		return;
-
-	/* Handle 48-bit counter overflow */
-	delta = (count << COUNTER_SHIFT) - (prev << COUNTER_SHIFT);
-	delta >>= COUNTER_SHIFT;
-	local64_add(delta, &event->count);
+	/*
+	 * Since the counter always start with value zero,
+	 * simply just accumulate the count for the event.
+	 */
+	local64_add(count, &event->count);
 }
 
 static void perf_iommu_stop(struct perf_event *event, int flags)
@@ -329,15 +333,16 @@ static void perf_iommu_stop(struct perf_event *event, int flags)
 	if (hwc->state & PERF_HES_UPTODATE)
 		return;
 
+	/*
+	 * To account for power-gating, in which reading the counter would
+	 * return zero, we need to read the register before disabling.
+	 */
+	perf_iommu_read(event);
+	hwc->state |= PERF_HES_UPTODATE;
+
 	perf_iommu_disable_event(event);
 	WARN_ON_ONCE(hwc->state & PERF_HES_STOPPED);
 	hwc->state |= PERF_HES_STOPPED;
-
-	if (hwc->state & PERF_HES_UPTODATE)
-		return;
-
-	perf_iommu_read(event);
-	hwc->state |= PERF_HES_UPTODATE;
 }
 
 static int perf_iommu_add(struct perf_event *event, int flags)