diff --git a/drivers/base/firmware_class.c b/drivers/base/firmware_class.c index d276e33880be1e632a95f5a960e4c0eab6c7f983..63f165c59da8bf93455f45dad007420bf1c08336 100644 --- a/drivers/base/firmware_class.c +++ b/drivers/base/firmware_class.c @@ -28,6 +28,7 @@ #include #include #include +#include #include @@ -308,12 +309,17 @@ static int fw_read_file_contents(struct file *file, struct firmware_buf *fw_buf) if (rc != size) { if (rc > 0) rc = -EIO; - vfree(buf); - return rc; + goto fail; } + rc = security_kernel_fw_from_file(file, buf, size); + if (rc) + goto fail; fw_buf->data = buf; fw_buf->size = size; return 0; +fail: + vfree(buf); + return rc; } static int fw_get_filesystem_firmware(struct device *device, @@ -617,6 +623,7 @@ static ssize_t firmware_loading_store(struct device *dev, { struct firmware_priv *fw_priv = to_firmware_priv(dev); struct firmware_buf *fw_buf; + ssize_t written = count; int loading = simple_strtol(buf, NULL, 10); int i; @@ -640,6 +647,8 @@ static ssize_t firmware_loading_store(struct device *dev, break; case 0: if (test_bit(FW_STATUS_LOADING, &fw_buf->status)) { + int rc; + set_bit(FW_STATUS_DONE, &fw_buf->status); clear_bit(FW_STATUS_LOADING, &fw_buf->status); @@ -649,10 +658,23 @@ static ssize_t firmware_loading_store(struct device *dev, * see the mapped 'buf->data' once the loading * is completed. * */ - if (fw_map_pages_buf(fw_buf)) + rc = fw_map_pages_buf(fw_buf); + if (rc) dev_err(dev, "%s: map pages failed\n", __func__); + else + rc = security_kernel_fw_from_file(NULL, + fw_buf->data, fw_buf->size); + + /* + * Same logic as fw_load_abort, only the DONE bit + * is ignored and we set ABORT only on failure. + */ list_del_init(&fw_buf->pending_list); + if (rc) { + set_bit(FW_STATUS_ABORT, &fw_buf->status); + written = rc; + } complete_all(&fw_buf->completion); break; } @@ -666,7 +688,7 @@ static ssize_t firmware_loading_store(struct device *dev, } out: mutex_unlock(&fw_lock); - return count; + return written; } static DEVICE_ATTR(loading, 0644, firmware_loading_show, firmware_loading_store);