diff --git a/drivers/net/wireless/intel/iwlwifi/mvm/d3.c b/drivers/net/wireless/intel/iwlwifi/mvm/d3.c
index 8cd2e67860754bebbcbb9715e4585f7458f154e3..69fd10a2f2a3197d9299d0b52a091da1bd6eeaa2 100644
--- a/drivers/net/wireless/intel/iwlwifi/mvm/d3.c
+++ b/drivers/net/wireless/intel/iwlwifi/mvm/d3.c
@@ -779,6 +779,9 @@ static int iwl_mvm_switch_to_d3(struct iwl_mvm *mvm)
 	 */
 	set_bit(IWL_MVM_STATUS_IN_HW_RESTART, &mvm->status);
 
+	/* the fw is reset, so all the keys are cleared */
+	memset(mvm->fw_key_table, 0, sizeof(mvm->fw_key_table));
+
 	mvm->ptk_ivlen = 0;
 	mvm->ptk_icvlen = 0;
 	mvm->ptk_ivlen = 0;
diff --git a/drivers/net/wireless/intel/iwlwifi/mvm/mac80211.c b/drivers/net/wireless/intel/iwlwifi/mvm/mac80211.c
index d239e97ab98a766f90f4bd95d2cea8d1582cd916..be5703c44c2241740474c8ab71ea0d5f1f174f9c 100644
--- a/drivers/net/wireless/intel/iwlwifi/mvm/mac80211.c
+++ b/drivers/net/wireless/intel/iwlwifi/mvm/mac80211.c
@@ -984,6 +984,7 @@ static void iwl_mvm_restart_cleanup(struct iwl_mvm *mvm)
 	mvm->d0i3_ap_sta_id = IWL_MVM_STATION_COUNT;
 
 	iwl_mvm_reset_phy_ctxts(mvm);
+	memset(mvm->fw_key_table, 0, sizeof(mvm->fw_key_table));
 	memset(mvm->sta_drained, 0, sizeof(mvm->sta_drained));
 	memset(mvm->tfd_drained, 0, sizeof(mvm->tfd_drained));
 	memset(&mvm->last_bt_notif, 0, sizeof(mvm->last_bt_notif));
diff --git a/drivers/net/wireless/intel/iwlwifi/mvm/sta.c b/drivers/net/wireless/intel/iwlwifi/mvm/sta.c
index 4148ebea5373e644d31a12e5f3d2df1aae8fe814..92edacc298da5f7312fd22c28cde4d9445e354c6 100644
--- a/drivers/net/wireless/intel/iwlwifi/mvm/sta.c
+++ b/drivers/net/wireless/intel/iwlwifi/mvm/sta.c
@@ -1198,8 +1198,6 @@ static int iwl_mvm_set_fw_key_idx(struct iwl_mvm *mvm)
 	if (max_offs < 0)
 		return STA_KEY_IDX_INVALID;
 
-	__set_bit(max_offs, mvm->fw_key_table);
-
 	return max_offs;
 }
 
@@ -1507,10 +1505,8 @@ int iwl_mvm_set_sta_key(struct iwl_mvm *mvm,
 	}
 
 	ret = __iwl_mvm_set_sta_key(mvm, vif, sta, keyconf, key_offset, mcast);
-	if (ret) {
-		__clear_bit(keyconf->hw_key_idx, mvm->fw_key_table);
+	if (ret)
 		goto end;
-	}
 
 	/*
 	 * For WEP, the same key is used for multicast and unicast. Upload it
@@ -1523,11 +1519,13 @@ int iwl_mvm_set_sta_key(struct iwl_mvm *mvm,
 		ret = __iwl_mvm_set_sta_key(mvm, vif, sta, keyconf,
 					    key_offset, !mcast);
 		if (ret) {
-			__clear_bit(keyconf->hw_key_idx, mvm->fw_key_table);
 			__iwl_mvm_remove_sta_key(mvm, sta_id, keyconf, mcast);
+			goto end;
 		}
 	}
 
+	__set_bit(key_offset, mvm->fw_key_table);
+
 end:
 	IWL_DEBUG_WEP(mvm, "key: cipher=%x len=%d idx=%d sta=%pM ret=%d\n",
 		      keyconf->cipher, keyconf->keylen, keyconf->keyidx,