diff --git a/mm/mempolicy.c b/mm/mempolicy.c index 92daa267baf2e7783af034f8302642931a202c72..f0728ae74672262939e58e920be13f60b893e18f 100644 --- a/mm/mempolicy.c +++ b/mm/mempolicy.c @@ -607,24 +607,39 @@ check_range(struct mm_struct *mm, unsigned long start, unsigned long end, return first; } -/* Apply policy to a single VMA */ -static int policy_vma(struct vm_area_struct *vma, struct mempolicy *new) +/* + * Apply policy to a single VMA + * This must be called with the mmap_sem held for writing. + */ +static int vma_replace_policy(struct vm_area_struct *vma, + struct mempolicy *pol) { - int err = 0; - struct mempolicy *old = vma->vm_policy; + int err; + struct mempolicy *old; + struct mempolicy *new; pr_debug("vma %lx-%lx/%lx vm_ops %p vm_file %p set_policy %p\n", vma->vm_start, vma->vm_end, vma->vm_pgoff, vma->vm_ops, vma->vm_file, vma->vm_ops ? vma->vm_ops->set_policy : NULL); - if (vma->vm_ops && vma->vm_ops->set_policy) + new = mpol_dup(pol); + if (IS_ERR(new)) + return PTR_ERR(new); + + if (vma->vm_ops && vma->vm_ops->set_policy) { err = vma->vm_ops->set_policy(vma, new); - if (!err) { - mpol_get(new); - vma->vm_policy = new; - mpol_put(old); + if (err) + goto err_out; } + + old = vma->vm_policy; + vma->vm_policy = new; /* protected by mmap_sem */ + mpol_put(old); + + return 0; + err_out: + mpol_put(new); return err; } @@ -676,7 +691,7 @@ static int mbind_range(struct mm_struct *mm, unsigned long start, if (err) goto out; } - err = policy_vma(vma, new_pol); + err = vma_replace_policy(vma, new_pol); if (err) goto out; } @@ -2153,15 +2168,24 @@ static void sp_delete(struct shared_policy *sp, struct sp_node *n) static struct sp_node *sp_alloc(unsigned long start, unsigned long end, struct mempolicy *pol) { - struct sp_node *n = kmem_cache_alloc(sn_cache, GFP_KERNEL); + struct sp_node *n; + struct mempolicy *newpol; + n = kmem_cache_alloc(sn_cache, GFP_KERNEL); if (!n) return NULL; + + newpol = mpol_dup(pol); + if (IS_ERR(newpol)) { + kmem_cache_free(sn_cache, n); + return NULL; + } + newpol->flags |= MPOL_F_SHARED; + n->start = start; n->end = end; - mpol_get(pol); - pol->flags |= MPOL_F_SHARED; /* for unref */ - n->policy = pol; + n->policy = newpol; + return n; }