amd_iommu_v2.c 21.3 KB
Newer Older
1 2
/*
 * Copyright (C) 2010-2012 Advanced Micro Devices, Inc.
 * Author: Joerg Roedel <jroedel@suse.de>
 *
 * This program is free software; you can redistribute it and/or modify it
 * under the terms of the GNU General Public License version 2 as published
 * by the Free Software Foundation.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program; if not, write to the Free Software
 * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307 USA
 */

19
#include <linux/mmu_notifier.h>
20 21
#include <linux/amd-iommu.h>
#include <linux/mm_types.h>
22
#include <linux/profile.h>
23
#include <linux/module.h>
24
#include <linux/sched.h>
25
#include <linux/iommu.h>
26
#include <linux/wait.h>
27 28 29
#include <linux/pci.h>
#include <linux/gfp.h>

30
#include "amd_iommu_types.h"
31
#include "amd_iommu_proto.h"
32 33

/* Module metadata */
MODULE_LICENSE("GPL v2");
MODULE_AUTHOR("Joerg Roedel <jroedel@suse.de>");
35

36 37 38 39 40 41
/* Number of 16-bit device ids (see device_id()); bound of the exit scan */
#define MAX_DEVICES		0x10000
/* PRI tags tracked per PASID; ppr_notifier() masks tags to 9 bits */
#define PRI_QUEUE_SIZE		512

/*
 * Completion bookkeeping for one PRI tag of a PASID.  All fields are
 * updated under pasid_state->lock.
 */
struct pri_queue {
	atomic_t inflight;	/* Faults queued for this tag, not yet finished */
	bool finish;		/* A completion response must be sent */
	int status;		/* PPR_* status reported on completion */
};

struct pasid_state {
	struct list_head list;			/* For global state-list */
	atomic_t count;				/* Reference count */
48
	unsigned mmu_notifier_count;		/* Counting nested mmu_notifier
49
						   calls */
50
	struct mm_struct *mm;			/* mm_struct for the faults */
51
	struct mmu_notifier mn;                 /* mmu_notifier handle */
52 53 54
	struct pri_queue pri[PRI_QUEUE_SIZE];	/* PRI tag states */
	struct device_state *device_state;	/* Link to our device_state */
	int pasid;				/* PASID index */
55 56
	bool invalid;				/* Used during setup and
						   teardown of the pasid */
57 58
	spinlock_t lock;			/* Protect pri_queues and
						   mmu_notifer_count */
59
	wait_queue_head_t wq;			/* To wait for count == 0 */
60 61 62
};

struct device_state {
63 64
	struct list_head list;
	u16 devid;
65 66 67 68 69 70
	atomic_t count;
	struct pci_dev *pdev;
	struct pasid_state **states;
	struct iommu_domain *domain;
	int pasid_levels;
	int max_pasids;
J
Joerg Roedel 已提交
71
	amd_iommu_invalid_ppr_cb inv_ppr_cb;
72
	amd_iommu_invalidate_ctx inv_ctx_cb;
73
	spinlock_t lock;
74 75 76 77 78 79 80 81 82 83 84 85 86 87
	wait_queue_head_t wq;
};

/*
 * One PPR fault request, allocated by ppr_notifier() and processed (then
 * freed) by do_fault() on iommu_wq.
 */
struct fault {
	struct work_struct work;
	struct device_state *dev_state;
	struct pasid_state *state;	/* Reference dropped in do_fault() */
	struct mm_struct *mm;		/* NOTE(review): not used in this file */
	u64 address;
	u16 devid;			/* NOTE(review): not used in this file */
	u16 pasid;
	u16 tag;
	u16 finish;
	u16 flags;
};

90
static LIST_HEAD(state_list);
91 92
static spinlock_t state_lock;

93 94
static struct workqueue_struct *iommu_wq;

95
static void free_pasid_states(struct device_state *dev_state);
96 97 98 99 100 101 102 103 104 105 106

static u16 device_id(struct pci_dev *pdev)
{
	u16 devid;

	devid = pdev->bus->number;
	devid = (devid << 8) | pdev->devfn;

	return devid;
}

107 108
/*
 * Find the device_state registered for @devid, or NULL.  Takes no
 * reference; every caller in this file holds state_lock around the call.
 */
static struct device_state *__get_device_state(u16 devid)
{
	struct device_state *entry;

	list_for_each_entry(entry, &state_list, list) {
		if (entry->devid != devid)
			continue;

		return entry;
	}

	return NULL;
}

119 120 121 122 123 124
/*
 * Look up @devid under state_lock and take a reference on the result.
 * Returns NULL if the device is not registered; otherwise the caller must
 * balance this with put_device_state().
 */
static struct device_state *get_device_state(u16 devid)
{
	struct device_state *ds;
	unsigned long irqflags;

	spin_lock_irqsave(&state_lock, irqflags);

	ds = __get_device_state(devid);
	if (ds)
		atomic_inc(&ds->count);

	spin_unlock_irqrestore(&state_lock, irqflags);

	return ds;
}

static void free_device_state(struct device_state *dev_state)
{
135 136 137 138
	/*
	 * First detach device from domain - No more PRI requests will arrive
	 * from that device after it is unbound from the IOMMUv2 domain.
	 */
139
	iommu_detach_device(dev_state->domain, &dev_state->pdev->dev);
140 141

	/* Everything is down now, free the IOMMUv2 domain */
142
	iommu_domain_free(dev_state->domain);
143 144

	/* Finally get rid of the device-state */
145 146 147 148 149 150
	kfree(dev_state);
}

/*
 * Drop one reference.  Dropping the last one wakes dev_state->wq (waited
 * on by amd_iommu_free_device()); nothing is freed here.
 */
static void put_device_state(struct device_state *dev_state)
{
	if (!atomic_dec_and_test(&dev_state->count))
		return;

	wake_up(&dev_state->wq);
}

154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262
/*
 * Return the address of the leaf slot for @pasid in the PASID table of
 * @dev_state.  Each level consumes 9 bits of the PASID, starting at level
 * dev_state->pasid_levels and ending at level 0 (the leaf).
 *
 * When an intermediate table is missing, it is allocated (zeroed, with
 * GFP_ATOMIC) if @alloc is true; otherwise NULL is returned.  NULL is also
 * returned when that allocation fails.
 *
 * Must be called under dev_state->lock.
 */
static struct pasid_state **__get_pasid_state_ptr(struct device_state *dev_state,
						  int pasid, bool alloc)
{
	struct pasid_state **table = dev_state->states;
	struct pasid_state **slot;
	int lvl;

	for (lvl = dev_state->pasid_levels; ; --lvl) {
		slot = &table[(pasid >> (9 * lvl)) & 0x1ff];

		/* Leaf level reached: the slot holds the pasid_state pointer */
		if (lvl == 0)
			return slot;

		if (*slot == NULL) {
			if (!alloc)
				return NULL;

			*slot = (void *)get_zeroed_page(GFP_ATOMIC);
			if (*slot == NULL)
				return NULL;
		}

		/* Descend into the next-lower table */
		table = (struct pasid_state **)*slot;
	}
}

/*
 * Store @pasid_state in the table slot for @pasid, allocating missing
 * intermediate tables.
 *
 * Returns 0 on success.  Returns -ENOMEM both when a table allocation
 * fails and when the slot is already occupied.
 * NOTE(review): -EBUSY would describe the occupied case more accurately.
 */
static int set_pasid_state(struct device_state *dev_state,
			   struct pasid_state *pasid_state,
			   int pasid)
{
	struct pasid_state **slot;
	unsigned long irqflags;
	int err = -ENOMEM;

	spin_lock_irqsave(&dev_state->lock, irqflags);

	slot = __get_pasid_state_ptr(dev_state, pasid, true);
	if (slot != NULL && *slot == NULL) {
		*slot = pasid_state;
		err = 0;
	}

	spin_unlock_irqrestore(&dev_state->lock, irqflags);

	return err;
}

static void clear_pasid_state(struct device_state *dev_state, int pasid)
{
	struct pasid_state **ptr;
	unsigned long flags;

	spin_lock_irqsave(&dev_state->lock, flags);
	ptr = __get_pasid_state_ptr(dev_state, pasid, true);

	if (ptr == NULL)
		goto out_unlock;

	*ptr = NULL;

out_unlock:
	spin_unlock_irqrestore(&dev_state->lock, flags);
}

/*
 * Look up the pasid_state bound to @pasid and take a reference on it.
 * Never allocates table pages.  Returns NULL when nothing is bound;
 * otherwise the caller must call put_pasid_state().
 */
static struct pasid_state *get_pasid_state(struct device_state *dev_state,
					   int pasid)
{
	struct pasid_state **slot;
	struct pasid_state *found = NULL;
	unsigned long irqflags;

	spin_lock_irqsave(&dev_state->lock, irqflags);

	slot = __get_pasid_state_ptr(dev_state, pasid, false);
	if (slot != NULL && *slot != NULL) {
		found = *slot;
		atomic_inc(&found->count);
	}

	spin_unlock_irqrestore(&dev_state->lock, irqflags);

	return found;
}

/* Release the memory of a pasid_state; performs no other cleanup. */
static void free_pasid_state(struct pasid_state *pasid_state)
{
	kfree(pasid_state);
}

static void put_pasid_state(struct pasid_state *pasid_state)
{
263
	if (atomic_dec_and_test(&pasid_state->count))
264
		wake_up(&pasid_state->wq);
265 266
}

267 268
static void put_pasid_state_wait(struct pasid_state *pasid_state)
{
269
	wait_event(pasid_state->wq, !atomic_read(&pasid_state->count));
270 271 272
	free_pasid_state(pasid_state);
}

273
static void unbind_pasid(struct pasid_state *pasid_state)
274 275 276 277 278
{
	struct iommu_domain *domain;

	domain = pasid_state->device_state->domain;

279 280 281 282 283 284 285 286 287 288
	/*
	 * Mark pasid_state as invalid, no more faults will we added to the
	 * work queue after this is visible everywhere.
	 */
	pasid_state->invalid = true;

	/* Make sure this is visible */
	smp_wmb();

	/* After this the device/pasid can't access the mm anymore */
289 290 291 292 293 294
	amd_iommu_domain_clear_gcr3(domain, pasid_state->pasid);

	/* Make sure no more pending faults are in the queue */
	flush_workqueue(iommu_wq);
}

295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331
/*
 * Free every page referenced by a non-NULL entry of the 512-entry table
 * @tbl.  The page holding @tbl itself is left to the caller.
 */
static void free_pasid_states_level1(struct pasid_state **tbl)
{
	int idx;

	for (idx = 0; idx < 512; idx++) {
		if (tbl[idx])
			free_page((unsigned long)tbl[idx]);
	}
}

/*
 * Free the tables below the root of a two-level PASID table.
 *
 * Each non-NULL entry of @tbl points to an intermediate table page whose
 * entries point to leaf table pages.  free_pasid_states_level1() frees
 * only the leaf pages an intermediate table references, so the
 * intermediate page must be freed here as well; otherwise it leaks.
 * The root page @tbl itself is freed by the caller.
 */
static void free_pasid_states_level2(struct pasid_state **tbl)
{
	struct pasid_state **ptr;
	int i;

	for (i = 0; i < 512; ++i) {
		if (tbl[i] == NULL)
			continue;

		ptr = (struct pasid_state **)tbl[i];
		free_pasid_states_level1(ptr);
		free_page((unsigned long)ptr);
	}
}

static void free_pasid_states(struct device_state *dev_state)
{
	struct pasid_state *pasid_state;
	int i;

	for (i = 0; i < dev_state->max_pasids; ++i) {
		pasid_state = get_pasid_state(dev_state, i);
		if (pasid_state == NULL)
			continue;

		put_pasid_state(pasid_state);
332 333 334 335 336 337

		/*
		 * This will call the mn_release function and
		 * unbind the PASID
		 */
		mmu_notifier_unregister(&pasid_state->mn, pasid_state->mm);
338 339

		put_pasid_state_wait(pasid_state); /* Reference taken in
340
						      amd_iommu_bind_pasid */
341 342 343

		/* Drop reference taken in amd_iommu_bind_pasid */
		put_device_state(dev_state);
344 345 346 347 348 349 350 351 352 353 354 355
	}

	if (dev_state->pasid_levels == 2)
		free_pasid_states_level2(dev_state->states);
	else if (dev_state->pasid_levels == 1)
		free_pasid_states_level1(dev_state->states);
	else if (dev_state->pasid_levels != 0)
		BUG();

	free_page((unsigned long)dev_state->states);
}

356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374
/* Map an embedded mmu_notifier back to the pasid_state that contains it. */
static struct pasid_state *mn_to_state(struct mmu_notifier *mn)
{
	struct pasid_state *owner = container_of(mn, struct pasid_state, mn);

	return owner;
}

/* Flush the IOMMU TLB entry for @address in the PASID owning @mn. */
static void __mn_flush_page(struct mmu_notifier *mn,
			    unsigned long address)
{
	struct pasid_state *ps = mn_to_state(mn);

	amd_iommu_flush_page(ps->device_state->domain, ps->pasid, address);
}

static int mn_clear_flush_young(struct mmu_notifier *mn,
				struct mm_struct *mm,
A
Andres Lagar-Cavilla 已提交
375 376
				unsigned long start,
				unsigned long end)
377
{
A
Andres Lagar-Cavilla 已提交
378 379
	for (; start < end; start += PAGE_SIZE)
		__mn_flush_page(mn, start);
380 381 382 383 384 385 386 387 388 389 390

	return 0;
}

/* mmu_notifier .invalidate_page hook: flush the single page @address. */
static void mn_invalidate_page(struct mmu_notifier *mn,
			       struct mm_struct *mm,
			       unsigned long address)
{
	unsigned long page_addr = address;

	__mn_flush_page(mn, page_addr);
}

391 392 393
static void mn_invalidate_range(struct mmu_notifier *mn,
				struct mm_struct *mm,
				unsigned long start, unsigned long end)
394 395 396 397 398 399 400
{
	struct pasid_state *pasid_state;
	struct device_state *dev_state;

	pasid_state = mn_to_state(mn);
	dev_state   = pasid_state->device_state;

401 402 403 404 405
	if ((start ^ (end - 1)) < PAGE_SIZE)
		amd_iommu_flush_page(dev_state->domain, pasid_state->pasid,
				     start);
	else
		amd_iommu_flush_tlb(dev_state->domain, pasid_state->pasid);
406 407
}

408 409 410 411
static void mn_release(struct mmu_notifier *mn, struct mm_struct *mm)
{
	struct pasid_state *pasid_state;
	struct device_state *dev_state;
412
	bool run_inv_ctx_cb;
413 414 415

	might_sleep();

416 417 418
	pasid_state    = mn_to_state(mn);
	dev_state      = pasid_state->device_state;
	run_inv_ctx_cb = !pasid_state->invalid;
419

420
	if (run_inv_ctx_cb && pasid_state->device_state->inv_ctx_cb)
421 422
		dev_state->inv_ctx_cb(dev_state->pdev, pasid_state->pasid);

423
	unbind_pasid(pasid_state);
424 425
}

426
static struct mmu_notifier_ops iommu_mn = {
427
	.release		= mn_release,
428 429
	.clear_flush_young      = mn_clear_flush_young,
	.invalidate_page        = mn_invalidate_page,
430
	.invalidate_range       = mn_invalidate_range,
431 432
};

433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459
/* Record the PPR_* @status to report when PRI tag @tag completes. */
static void set_pri_tag_status(struct pasid_state *pasid_state,
			       u16 tag, int status)
{
	struct pri_queue *pri = &pasid_state->pri[tag];
	unsigned long irqflags;

	spin_lock_irqsave(&pasid_state->lock, irqflags);
	pri->status = status;
	spin_unlock_irqrestore(&pasid_state->lock, irqflags);
}

/*
 * Account one finished fault for PRI tag @tag.  When it was the last
 * in-flight fault and the device requested a completion, send the recorded
 * status to the device and reset the tag to finish=false/PPR_SUCCESS.
 */
static void finish_pri_tag(struct device_state *dev_state,
			   struct pasid_state *pasid_state,
			   u16 tag)
{
	struct pri_queue *pri = &pasid_state->pri[tag];
	unsigned long irqflags;

	spin_lock_irqsave(&pasid_state->lock, irqflags);

	if (atomic_dec_and_test(&pri->inflight) && pri->finish) {
		amd_iommu_complete_ppr(dev_state->pdev, pasid_state->pasid,
				       pri->status, tag);
		pri->finish = false;
		pri->status = PPR_SUCCESS;
	}

	spin_unlock_irqrestore(&pasid_state->lock, irqflags);
}

460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487
static void handle_fault_error(struct fault *fault)
{
	int status;

	if (!fault->dev_state->inv_ppr_cb) {
		set_pri_tag_status(fault->state, fault->tag, PPR_INVALID);
		return;
	}

	status = fault->dev_state->inv_ppr_cb(fault->dev_state->pdev,
					      fault->pasid,
					      fault->address,
					      fault->flags);
	switch (status) {
	case AMD_IOMMU_INV_PRI_RSP_SUCCESS:
		set_pri_tag_status(fault->state, fault->tag, PPR_SUCCESS);
		break;
	case AMD_IOMMU_INV_PRI_RSP_INVALID:
		set_pri_tag_status(fault->state, fault->tag, PPR_INVALID);
		break;
	case AMD_IOMMU_INV_PRI_RSP_FAIL:
		set_pri_tag_status(fault->state, fault->tag, PPR_FAILURE);
		break;
	default:
		BUG();
	}
}

488 489 490
static void do_fault(struct work_struct *work)
{
	struct fault *fault = container_of(work, struct fault, work);
491 492 493 494
	struct mm_struct *mm;
	struct vm_area_struct *vma;
	u64 address;
	int ret, write;
495 496 497

	write = !!(fault->flags & PPR_FAULT_WRITE);

498 499 500 501 502 503 504 505 506 507
	mm = fault->state->mm;
	address = fault->address;

	down_read(&mm->mmap_sem);
	vma = find_extend_vma(mm, address);
	if (!vma || address < vma->vm_start) {
		/* failed to get a vma in the right range */
		up_read(&mm->mmap_sem);
		handle_fault_error(fault);
		goto out;
J
Joerg Roedel 已提交
508
	}
509

510 511 512 513 514 515 516 517 518 519 520
	ret = handle_mm_fault(mm, vma, address, write);
	if (ret & VM_FAULT_ERROR) {
		/* failed to service fault */
		up_read(&mm->mmap_sem);
		handle_fault_error(fault);
		goto out;
	}

	up_read(&mm->mmap_sem);

out:
521 522 523 524 525 526 527 528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548
	finish_pri_tag(fault->dev_state, fault->state, fault->tag);

	put_pasid_state(fault->state);

	kfree(fault);
}

static int ppr_notifier(struct notifier_block *nb, unsigned long e, void *data)
{
	struct amd_iommu_fault *iommu_fault;
	struct pasid_state *pasid_state;
	struct device_state *dev_state;
	unsigned long flags;
	struct fault *fault;
	bool finish;
	u16 tag;
	int ret;

	iommu_fault = data;
	tag         = iommu_fault->tag & 0x1ff;
	finish      = (iommu_fault->tag >> 9) & 1;

	ret = NOTIFY_DONE;
	dev_state = get_device_state(iommu_fault->device_id);
	if (dev_state == NULL)
		goto out;

	pasid_state = get_pasid_state(dev_state, iommu_fault->pasid);
549
	if (pasid_state == NULL || pasid_state->invalid) {
550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573
		/* We know the device but not the PASID -> send INVALID */
		amd_iommu_complete_ppr(dev_state->pdev, iommu_fault->pasid,
				       PPR_INVALID, tag);
		goto out_drop_state;
	}

	spin_lock_irqsave(&pasid_state->lock, flags);
	atomic_inc(&pasid_state->pri[tag].inflight);
	if (finish)
		pasid_state->pri[tag].finish = true;
	spin_unlock_irqrestore(&pasid_state->lock, flags);

	fault = kzalloc(sizeof(*fault), GFP_ATOMIC);
	if (fault == NULL) {
		/* We are OOM - send success and let the device re-fault */
		finish_pri_tag(dev_state, pasid_state, tag);
		goto out_drop_state;
	}

	fault->dev_state = dev_state;
	fault->address   = iommu_fault->address;
	fault->state     = pasid_state;
	fault->tag       = tag;
	fault->finish    = finish;
574
	fault->pasid     = iommu_fault->pasid;
575 576 577 578 579 580 581 582
	fault->flags     = iommu_fault->flags;
	INIT_WORK(&fault->work, do_fault);

	queue_work(iommu_wq, &fault->work);

	ret = NOTIFY_OK;

out_drop_state:
583 584 585 586

	if (ret != NOTIFY_OK && pasid_state)
		put_pasid_state(pasid_state);

587 588 589 590 591 592 593 594 595 596
	put_device_state(dev_state);

out:
	return ret;
}

/* Registered in amd_iommu_v2_init(), unregistered in amd_iommu_v2_exit() */
static struct notifier_block ppr_nb = {
	.notifier_call	= ppr_notifier,
};

597 598 599 600 601
int amd_iommu_bind_pasid(struct pci_dev *pdev, int pasid,
			 struct task_struct *task)
{
	struct pasid_state *pasid_state;
	struct device_state *dev_state;
602
	struct mm_struct *mm;
603 604 605 606 607 608 609 610 611 612 613 614 615 616 617 618 619 620 621 622 623 624 625
	u16 devid;
	int ret;

	might_sleep();

	if (!amd_iommu_v2_supported())
		return -ENODEV;

	devid     = device_id(pdev);
	dev_state = get_device_state(devid);

	if (dev_state == NULL)
		return -EINVAL;

	ret = -EINVAL;
	if (pasid < 0 || pasid >= dev_state->max_pasids)
		goto out;

	ret = -ENOMEM;
	pasid_state = kzalloc(sizeof(*pasid_state), GFP_KERNEL);
	if (pasid_state == NULL)
		goto out;

626

627
	atomic_set(&pasid_state->count, 1);
628
	init_waitqueue_head(&pasid_state->wq);
629 630
	spin_lock_init(&pasid_state->lock);

631 632
	mm                        = get_task_mm(task);
	pasid_state->mm           = mm;
633 634
	pasid_state->device_state = dev_state;
	pasid_state->pasid        = pasid;
635 636
	pasid_state->invalid      = true; /* Mark as valid only if we are
					     done with setting up the pasid */
637
	pasid_state->mn.ops       = &iommu_mn;
638 639 640 641

	if (pasid_state->mm == NULL)
		goto out_free;

642
	mmu_notifier_register(&pasid_state->mn, mm);
643

644 645
	ret = set_pasid_state(dev_state, pasid_state, pasid);
	if (ret)
646
		goto out_unregister;
647 648 649 650 651 652

	ret = amd_iommu_domain_set_gcr3(dev_state->domain, pasid,
					__pa(pasid_state->mm->pgd));
	if (ret)
		goto out_clear_state;

653 654 655
	/* Now we are ready to handle faults */
	pasid_state->invalid = false;

656 657 658 659 660 661 662
	/*
	 * Drop the reference to the mm_struct here. We rely on the
	 * mmu_notifier release call-back to inform us when the mm
	 * is going away.
	 */
	mmput(mm);

663 664 665 666 667
	return 0;

out_clear_state:
	clear_pasid_state(dev_state, pasid);

668
out_unregister:
669
	mmu_notifier_unregister(&pasid_state->mn, mm);
670

671
out_free:
672
	mmput(mm);
673
	free_pasid_state(pasid_state);
674 675 676 677 678 679 680 681 682 683

out:
	put_device_state(dev_state);

	return ret;
}
EXPORT_SYMBOL(amd_iommu_bind_pasid);

void amd_iommu_unbind_pasid(struct pci_dev *pdev, int pasid)
{
684
	struct pasid_state *pasid_state;
685 686 687 688 689 690 691 692 693 694 695 696 697 698 699 700
	struct device_state *dev_state;
	u16 devid;

	might_sleep();

	if (!amd_iommu_v2_supported())
		return;

	devid = device_id(pdev);
	dev_state = get_device_state(devid);
	if (dev_state == NULL)
		return;

	if (pasid < 0 || pasid >= dev_state->max_pasids)
		goto out;

701 702 703 704 705 706 707 708 709
	pasid_state = get_pasid_state(dev_state, pasid);
	if (pasid_state == NULL)
		goto out;
	/*
	 * Drop reference taken here. We are safe because we still hold
	 * the reference taken in the amd_iommu_bind_pasid function.
	 */
	put_pasid_state(pasid_state);

710 711 712
	/* Clear the pasid state so that the pasid can be re-used */
	clear_pasid_state(dev_state, pasid_state->pasid);

713
	/*
714 715
	 * Call mmu_notifier_unregister to drop our reference
	 * to pasid_state->mm
716
	 */
717
	mmu_notifier_unregister(&pasid_state->mn, pasid_state->mm);
718

719
	put_pasid_state_wait(pasid_state); /* Reference taken in
720
					      amd_iommu_bind_pasid */
721
out:
722 723 724 725
	/* Drop reference taken in this function */
	put_device_state(dev_state);

	/* Drop reference taken in amd_iommu_bind_pasid */
726 727 728 729
	put_device_state(dev_state);
}
EXPORT_SYMBOL(amd_iommu_unbind_pasid);

730 731 732 733 734 735 736 737 738 739 740 741 742 743 744 745 746 747 748 749 750 751
int amd_iommu_init_device(struct pci_dev *pdev, int pasids)
{
	struct device_state *dev_state;
	unsigned long flags;
	int ret, tmp;
	u16 devid;

	might_sleep();

	if (!amd_iommu_v2_supported())
		return -ENODEV;

	if (pasids <= 0 || pasids > (PASID_MASK + 1))
		return -EINVAL;

	devid = device_id(pdev);

	dev_state = kzalloc(sizeof(*dev_state), GFP_KERNEL);
	if (dev_state == NULL)
		return -ENOMEM;

	spin_lock_init(&dev_state->lock);
752
	init_waitqueue_head(&dev_state->wq);
753 754
	dev_state->pdev  = pdev;
	dev_state->devid = devid;
755 756 757 758 759 760 761 762 763 764 765 766 767 768 769 770 771 772 773 774 775 776 777 778 779 780 781 782 783

	tmp = pasids;
	for (dev_state->pasid_levels = 0; (tmp - 1) & ~0x1ff; tmp >>= 9)
		dev_state->pasid_levels += 1;

	atomic_set(&dev_state->count, 1);
	dev_state->max_pasids = pasids;

	ret = -ENOMEM;
	dev_state->states = (void *)get_zeroed_page(GFP_KERNEL);
	if (dev_state->states == NULL)
		goto out_free_dev_state;

	dev_state->domain = iommu_domain_alloc(&pci_bus_type);
	if (dev_state->domain == NULL)
		goto out_free_states;

	amd_iommu_domain_direct_map(dev_state->domain);

	ret = amd_iommu_domain_enable_v2(dev_state->domain, pasids);
	if (ret)
		goto out_free_domain;

	ret = iommu_attach_device(dev_state->domain, &pdev->dev);
	if (ret != 0)
		goto out_free_domain;

	spin_lock_irqsave(&state_lock, flags);

784
	if (__get_device_state(devid) != NULL) {
785 786 787 788 789
		spin_unlock_irqrestore(&state_lock, flags);
		ret = -EBUSY;
		goto out_free_domain;
	}

790
	list_add_tail(&dev_state->list, &state_list);
791 792 793 794 795 796 797 798 799 800 801 802 803 804 805 806 807 808 809 810 811 812 813 814 815 816 817 818 819 820 821

	spin_unlock_irqrestore(&state_lock, flags);

	return 0;

out_free_domain:
	iommu_domain_free(dev_state->domain);

out_free_states:
	free_page((unsigned long)dev_state->states);

out_free_dev_state:
	kfree(dev_state);

	return ret;
}
EXPORT_SYMBOL(amd_iommu_init_device);

/*
 * amd_iommu_free_device - undo amd_iommu_init_device() for @pdev
 *
 * Unlinks the device_state from state_list so new lookups fail. Tears
 * down any PASIDs still bound, drops the initial reference and waits for
 * all remaining references before detaching and freeing everything.
 * Does nothing if the device was never initialized.
 */
void amd_iommu_free_device(struct pci_dev *pdev)
{
	struct device_state *dev_state;
	unsigned long flags;
	u16 devid;

	if (!amd_iommu_v2_supported())
		return;

	devid = device_id(pdev);

	spin_lock_irqsave(&state_lock, flags);
	dev_state = __get_device_state(devid);
	if (dev_state)
		list_del(&dev_state->list);
	spin_unlock_irqrestore(&state_lock, flags);

	if (!dev_state)
		return;

	/* Get rid of any remaining pasid states */
	free_pasid_states(dev_state);

	/* Drop the initial reference, then wait for all others to go away */
	put_device_state(dev_state);
	wait_event(dev_state->wq, !atomic_read(&dev_state->count));

	free_device_state(dev_state);
}
EXPORT_SYMBOL(amd_iommu_free_device);

J
Joerg Roedel 已提交
845 846 847 848 849 850 851 852 853 854 855 856 857 858 859 860
/*
 * amd_iommu_set_invalid_ppr_cb - install the callback consulted by
 * handle_fault_error() when a fault of @pdev cannot be resolved.
 *
 * Return: 0, -ENODEV without IOMMUv2, -EINVAL if @pdev is not initialized.
 */
int amd_iommu_set_invalid_ppr_cb(struct pci_dev *pdev,
				 amd_iommu_invalid_ppr_cb cb)
{
	struct device_state *dev_state;
	unsigned long flags;
	int ret = -EINVAL;

	if (!amd_iommu_v2_supported())
		return -ENODEV;

	spin_lock_irqsave(&state_lock, flags);

	dev_state = __get_device_state(device_id(pdev));
	if (dev_state) {
		dev_state->inv_ppr_cb = cb;
		ret = 0;
	}

	spin_unlock_irqrestore(&state_lock, flags);

	return ret;
}
EXPORT_SYMBOL(amd_iommu_set_invalid_ppr_cb);

876 877 878 879 880 881 882 883 884 885 886 887 888 889 890 891
/*
 * amd_iommu_set_invalidate_ctx_cb - install the callback mn_release()
 * invokes when the mm behind a valid PASID of @pdev goes away.
 *
 * Return: 0, -ENODEV without IOMMUv2, -EINVAL if @pdev is not initialized.
 */
int amd_iommu_set_invalidate_ctx_cb(struct pci_dev *pdev,
				    amd_iommu_invalidate_ctx cb)
{
	struct device_state *dev_state;
	unsigned long flags;
	int ret = -EINVAL;

	if (!amd_iommu_v2_supported())
		return -ENODEV;

	spin_lock_irqsave(&state_lock, flags);

	dev_state = __get_device_state(device_id(pdev));
	if (dev_state) {
		dev_state->inv_ctx_cb = cb;
		ret = 0;
	}

	spin_unlock_irqrestore(&state_lock, flags);

	return ret;
}
EXPORT_SYMBOL(amd_iommu_set_invalidate_ctx_cb);

907 908
/*
 * Module init.  On systems without IOMMUv2 the module still loads (so its
 * symbols are available), but nothing is set up.  Otherwise, create the
 * fault workqueue and register the PPR notifier.
 */
static int __init amd_iommu_v2_init(void)
{
	pr_info("AMD IOMMUv2 driver by Joerg Roedel <jroedel@suse.de>\n");

	if (!amd_iommu_v2_supported()) {
		pr_info("AMD IOMMUv2 functionality not available on this system\n");
		/*
		 * Load anyway to provide the symbols to other modules
		 * which may use AMD IOMMUv2 optionally.
		 */
		return 0;
	}

	spin_lock_init(&state_lock);

	iommu_wq = create_workqueue("amd_iommu_v2");
	if (!iommu_wq)
		return -ENOMEM;

	amd_iommu_register_ppr_notifier(&ppr_nb);

	return 0;
}

/*
 * Module exit: stop receiving PPR events, drain pending faults and free any
 * device that its driver left initialized (warning once about it).
 */
static void __exit amd_iommu_v2_exit(void)
{
	struct device_state *dev_state;
	int devid;

	if (!amd_iommu_v2_supported())
		return;

	amd_iommu_unregister_ppr_notifier(&ppr_nb);

	flush_workqueue(iommu_wq);

	/*
	 * Freeing a device can flush iommu_wq (via mn_release() and
	 * unbind_pasid()), so the workqueue is destroyed only after this loop.
	 */
	for (devid = 0; devid < MAX_DEVICES; ++devid) {
		dev_state = get_device_state(devid);
		if (!dev_state)
			continue;

		WARN_ON_ONCE(1);

		/* Drop the lookup reference; the initial one keeps it alive */
		put_device_state(dev_state);
		amd_iommu_free_device(dev_state->pdev);
	}

	destroy_workqueue(iommu_wq);
}

module_init(amd_iommu_v2_init);
module_exit(amd_iommu_v2_exit);