diff --git a/drivers/gpu/drm/amd/amdgpu/amdgpu_psp.c b/drivers/gpu/drm/amd/amdgpu/amdgpu_psp.c
index c224c5caba5b90b0db8a0559e7f101160468909a..e19369efb840e9f5c11f2f51ba31af5c691d08bb 100644
--- a/drivers/gpu/drm/amd/amdgpu/amdgpu_psp.c
+++ b/drivers/gpu/drm/amd/amdgpu/amdgpu_psp.c
@@ -333,14 +333,11 @@ static int psp_load_fw(struct amdgpu_device *adev)
 {
 	int ret;
 	struct psp_context *psp = &adev->psp;
-	struct psp_gfx_cmd_resp *cmd;
 
-	cmd = kzalloc(sizeof(struct psp_gfx_cmd_resp), GFP_KERNEL);
-	if (!cmd)
+	psp->cmd = kzalloc(sizeof(struct psp_gfx_cmd_resp), GFP_KERNEL);
+	if (!psp->cmd)
 		return -ENOMEM;
 
-	psp->cmd = cmd;
-
 	ret = amdgpu_bo_create_kernel(adev, PSP_1_MEG, PSP_1_MEG,
 				      AMDGPU_GEM_DOMAIN_GTT,
 				      &psp->fw_pri_bo,
@@ -379,8 +376,6 @@ static int psp_load_fw(struct amdgpu_device *adev)
 	if (ret)
 		goto failed_mem;
 
-	kfree(cmd);
-
 	return 0;
 
 failed_mem:
@@ -390,7 +385,8 @@ static int psp_load_fw(struct amdgpu_device *adev)
 	amdgpu_bo_free_kernel(&psp->fw_pri_bo,
 			      &psp->fw_pri_mc_addr, &psp->fw_pri_buf);
 failed:
-	kfree(cmd);
+	kfree(psp->cmd);
+	psp->cmd = NULL;
 	return ret;
 }
 
@@ -450,6 +446,9 @@ static int psp_hw_fini(void *handle)
 		amdgpu_bo_free_kernel(&psp->fence_buf_bo,
 				      &psp->fence_buf_mc_addr, &psp->fence_buf);
 
+	kfree(psp->cmd);
+	psp->cmd = NULL;
+
 	return 0;
 }