diff --git a/drivers/iommu/arm/arm-smmu/arm-smmu-impl.c b/drivers/iommu/arm/arm-smmu/arm-smmu-impl.c index f4ff124a1967dc9bb565c3586e5a53405f38bccc..a9861dcd0884350200814f041d84baf7ed786344 100644 --- a/drivers/iommu/arm/arm-smmu/arm-smmu-impl.c +++ b/drivers/iommu/arm/arm-smmu/arm-smmu-impl.c @@ -68,7 +68,8 @@ static int cavium_cfg_probe(struct arm_smmu_device *smmu) return 0; } -static int cavium_init_context(struct arm_smmu_domain *smmu_domain) +static int cavium_init_context(struct arm_smmu_domain *smmu_domain, + struct io_pgtable_cfg *pgtbl_cfg) { struct cavium_smmu *cs = container_of(smmu_domain->smmu, struct cavium_smmu, smmu); diff --git a/drivers/iommu/arm/arm-smmu/arm-smmu.c b/drivers/iommu/arm/arm-smmu/arm-smmu.c index 09c42af9f31e5dab510aba84c00355e5a1fde18f..37d8d49299b432cc8fed85e611c382a8ae53e257 100644 --- a/drivers/iommu/arm/arm-smmu/arm-smmu.c +++ b/drivers/iommu/arm/arm-smmu/arm-smmu.c @@ -795,11 +795,6 @@ static int arm_smmu_init_domain_context(struct iommu_domain *domain, cfg->asid = cfg->cbndx; smmu_domain->smmu = smmu; - if (smmu->impl && smmu->impl->init_context) { - ret = smmu->impl->init_context(smmu_domain); - if (ret) - goto out_unlock; - } pgtbl_cfg = (struct io_pgtable_cfg) { .pgsize_bitmap = smmu->pgsize_bitmap, @@ -810,6 +805,12 @@ static int arm_smmu_init_domain_context(struct iommu_domain *domain, .iommu_dev = smmu->dev, }; + if (smmu->impl && smmu->impl->init_context) { + ret = smmu->impl->init_context(smmu_domain, &pgtbl_cfg); + if (ret) + goto out_clear_smmu; + } + if (smmu_domain->non_strict) pgtbl_cfg.quirks |= IO_PGTABLE_QUIRK_NON_STRICT; diff --git a/drivers/iommu/arm/arm-smmu/arm-smmu.h b/drivers/iommu/arm/arm-smmu/arm-smmu.h index d890a4a968e8c9b6199e05ed121c3449eb8e1241..83294516ac080efe9a9141cda190dbee2a78942a 100644 --- a/drivers/iommu/arm/arm-smmu/arm-smmu.h +++ b/drivers/iommu/arm/arm-smmu/arm-smmu.h @@ -386,7 +386,8 @@ struct arm_smmu_impl { u64 val); int (*cfg_probe)(struct arm_smmu_device *smmu); int (*reset)(struct arm_smmu_device *smmu); - int (*init_context)(struct arm_smmu_domain *smmu_domain); + int (*init_context)(struct arm_smmu_domain *smmu_domain, + struct io_pgtable_cfg *cfg); void (*tlb_sync)(struct arm_smmu_device *smmu, int page, int sync, int status); int (*def_domain_type)(struct device *dev);