diff --git a/drivers/crypto/hisilicon/migration/acc_vf_migration.c b/drivers/crypto/hisilicon/migration/acc_vf_migration.c index cd38254c6598fb0dc9ebb7e7f55b990f70700015..cdaa55aa8716650fff1feb19c8c3153e368ecb65 100644 --- a/drivers/crypto/hisilicon/migration/acc_vf_migration.c +++ b/drivers/crypto/hisilicon/migration/acc_vf_migration.c @@ -1739,6 +1739,8 @@ static void acc_vf_remove(void *vendor_data) static struct vfio_pci_vendor_driver_ops sec_vf_mig_ops = { .owner = THIS_MODULE, .name = "hisi_sec2", + .vendor = PCI_VENDOR_ID_HUAWEI, + .device = PCI_DEVICE_ID_HUAWEI_SEC_VF, .probe = acc_vf_probe, .remove = acc_vf_remove, .device_ops = &acc_vf_device_ops_node, @@ -1747,6 +1749,8 @@ static struct vfio_pci_vendor_driver_ops sec_vf_mig_ops = { static struct vfio_pci_vendor_driver_ops hpre_vf_mig_ops = { .owner = THIS_MODULE, .name = "hisi_hpre", + .vendor = PCI_VENDOR_ID_HUAWEI, + .device = PCI_DEVICE_ID_HUAWEI_HPRE_VF, .probe = acc_vf_probe, .remove = acc_vf_remove, .device_ops = &acc_vf_device_ops_node, @@ -1755,6 +1759,8 @@ static struct vfio_pci_vendor_driver_ops hpre_vf_mig_ops = { static struct vfio_pci_vendor_driver_ops zip_vf_mig_ops = { .owner = THIS_MODULE, .name = "hisi_zip", + .vendor = PCI_VENDOR_ID_HUAWEI, + .device = PCI_DEVICE_ID_HUAWEI_ZIP_VF, .probe = acc_vf_probe, .remove = acc_vf_remove, .device_ops = &acc_vf_device_ops_node, @@ -1773,7 +1779,9 @@ static int __init acc_vf_module_init(void) static void __exit acc_vf_module_exit(void) { - vfio_pci_unregister_vendor_driver(&acc_vf_device_ops_node); + vfio_pci_unregister_vendor_driver(&sec_vf_mig_ops); + vfio_pci_unregister_vendor_driver(&hpre_vf_mig_ops); + vfio_pci_unregister_vendor_driver(&zip_vf_mig_ops); }; module_init(acc_vf_module_init); module_exit(acc_vf_module_exit); diff --git a/drivers/vfio/pci/vfio_pci.c b/drivers/vfio/pci/vfio_pci.c index 6d7ae4e3b98318f7edaabc020822098098612dbc..c04b34247067454c53df42d800e9bf0efdb570fb 100644 --- a/drivers/vfio/pci/vfio_pci.c +++ b/drivers/vfio/pci/vfio_pci.c @@ -2077,6 +2077,10 @@ static int probe_vendor_drivers(struct vfio_pci_device *vdev) list_for_each_entry(driver, &vfio_pci.vendor_drivers_list, next) { void *data; + if (vdev->pdev->vendor != driver->ops->vendor || + vdev->pdev->device != driver->ops->device) + continue; + if (!try_module_get(driver->ops->owner)) continue; @@ -2604,7 +2608,8 @@ int __vfio_pci_register_vendor_driver(struct vfio_pci_vendor_driver_ops *ops) /* Check for duplicates */ list_for_each_entry(tmp, &vfio_pci.vendor_drivers_list, next) { - if (tmp->ops->device_ops == ops->device_ops) { + if (tmp->ops->vendor == ops->vendor && + tmp->ops->vendor == ops->device) { mutex_unlock(&vfio_pci.vendor_drivers_lock); kfree(driver); return -EINVAL; @@ -2622,14 +2627,15 @@ int __vfio_pci_register_vendor_driver(struct vfio_pci_vendor_driver_ops *ops) } EXPORT_SYMBOL_GPL(__vfio_pci_register_vendor_driver); -void vfio_pci_unregister_vendor_driver(struct vfio_device_ops *device_ops) +void vfio_pci_unregister_vendor_driver(struct vfio_pci_vendor_driver_ops *ops) { struct vfio_pci_vendor_driver *driver, *tmp; mutex_lock(&vfio_pci.vendor_drivers_lock); list_for_each_entry_safe(driver, tmp, &vfio_pci.vendor_drivers_list, next) { - if (driver->ops->device_ops == device_ops) { + if (driver->ops->vendor == ops->vendor && + driver->ops->device == ops->device) { list_del(&driver->next); mutex_unlock(&vfio_pci.vendor_drivers_lock); kfree(driver); diff --git a/include/linux/vfio.h b/include/linux/vfio.h index 9a2217f13753a1186b2e9e569cb60484a4f47e19..f516784e0356f866ed36110d433df58934184edc 100644 --- a/include/linux/vfio.h +++ b/include/linux/vfio.h @@ -253,12 +253,15 @@ extern int vfio_pci_set_vendor_regions(void *device_data, struct vfio_pci_vendor_driver_ops { char *name; struct module *owner; + /* Used to match device */ + unsigned short vendor; + unsigned short device; void *(*probe)(struct pci_dev *pdev); void (*remove)(void *vendor_data); struct vfio_device_ops *device_ops; }; int __vfio_pci_register_vendor_driver(struct vfio_pci_vendor_driver_ops *ops); -void vfio_pci_unregister_vendor_driver(struct vfio_device_ops *device_ops); +void vfio_pci_unregister_vendor_driver(struct vfio_pci_vendor_driver_ops *ops); #define vfio_pci_register_vendor_driver(__name, __probe, __remove, \ __device_ops) \