[PATCH v4 17/49] mm: Change mprotect_fixup to vma iterator

From: Liam R. Howlett
Date: Fri Jan 20 2023 - 11:29:55 EST


From: "Liam R. Howlett" <Liam.Howlett@xxxxxxxxxx>

Pass a vma iterator to mprotect_fixup() so that the merge and split paths
can invalidate or update the iterator themselves, instead of each caller
having to do so.

Signed-off-by: Liam R. Howlett <Liam.Howlett@xxxxxxxxxx>
---
fs/exec.c | 5 ++++-
include/linux/mm.h | 6 +++---
mm/mprotect.c | 47 ++++++++++++++++++++++------------------------
3 files changed, 29 insertions(+), 29 deletions(-)

diff --git a/fs/exec.c b/fs/exec.c
index ab913243a367..b98647eeae9f 100644
--- a/fs/exec.c
+++ b/fs/exec.c
@@ -758,6 +758,7 @@ int setup_arg_pages(struct linux_binprm *bprm,
unsigned long stack_expand;
unsigned long rlim_stack;
struct mmu_gather tlb;
+ struct vma_iterator vmi;

#ifdef CONFIG_STACK_GROWSUP
/* Limit stack size */
@@ -812,8 +813,10 @@ int setup_arg_pages(struct linux_binprm *bprm,
vm_flags |= mm->def_flags;
vm_flags |= VM_STACK_INCOMPLETE_SETUP;

+ vma_iter_init(&vmi, mm, vma->vm_start);
+
tlb_gather_mmu(&tlb, mm);
- ret = mprotect_fixup(&tlb, vma, &prev, vma->vm_start, vma->vm_end,
+ ret = mprotect_fixup(&vmi, &tlb, vma, &prev, vma->vm_start, vma->vm_end,
vm_flags);
tlb_finish_mmu(&tlb);

diff --git a/include/linux/mm.h b/include/linux/mm.h
index 956025940053..bd0017ab13f3 100644
--- a/include/linux/mm.h
+++ b/include/linux/mm.h
@@ -2197,9 +2197,9 @@ bool can_change_pte_writable(struct vm_area_struct *vma, unsigned long addr,
extern long change_protection(struct mmu_gather *tlb,
struct vm_area_struct *vma, unsigned long start,
unsigned long end, unsigned long cp_flags);
-extern int mprotect_fixup(struct mmu_gather *tlb, struct vm_area_struct *vma,
- struct vm_area_struct **pprev, unsigned long start,
- unsigned long end, unsigned long newflags);
+extern int mprotect_fixup(struct vma_iterator *vmi, struct mmu_gather *tlb,
+ struct vm_area_struct *vma, struct vm_area_struct **pprev,
+ unsigned long start, unsigned long end, unsigned long newflags);

/*
* doesn't attempt to fault and will return short.
diff --git a/mm/mprotect.c b/mm/mprotect.c
index 6ecdf0671b81..42ceb0548754 100644
--- a/mm/mprotect.c
+++ b/mm/mprotect.c
@@ -585,9 +585,9 @@ static const struct mm_walk_ops prot_none_walk_ops = {
};

int
-mprotect_fixup(struct mmu_gather *tlb, struct vm_area_struct *vma,
- struct vm_area_struct **pprev, unsigned long start,
- unsigned long end, unsigned long newflags)
+mprotect_fixup(struct vma_iterator *vmi, struct mmu_gather *tlb,
+ struct vm_area_struct *vma, struct vm_area_struct **pprev,
+ unsigned long start, unsigned long end, unsigned long newflags)
{
struct mm_struct *mm = vma->vm_mm;
unsigned long oldflags = vma->vm_flags;
@@ -642,7 +642,7 @@ mprotect_fixup(struct mmu_gather *tlb, struct vm_area_struct *vma,
* First try to merge with previous and/or next vma.
*/
pgoff = vma->vm_pgoff + ((start - vma->vm_start) >> PAGE_SHIFT);
- *pprev = vma_merge(mm, *pprev, start, end, newflags,
+ *pprev = vmi_vma_merge(vmi, mm, *pprev, start, end, newflags,
vma->anon_vma, vma->vm_file, pgoff, vma_policy(vma),
vma->vm_userfaultfd_ctx, anon_vma_name(vma));
if (*pprev) {
@@ -654,13 +654,13 @@ mprotect_fixup(struct mmu_gather *tlb, struct vm_area_struct *vma,
*pprev = vma;

if (start != vma->vm_start) {
- error = split_vma(mm, vma, start, 1);
+ error = vmi_split_vma(vmi, mm, vma, start, 1);
if (error)
goto fail;
}

if (end != vma->vm_end) {
- error = split_vma(mm, vma, end, 0);
+ error = vmi_split_vma(vmi, mm, vma, end, 0);
if (error)
goto fail;
}
@@ -709,7 +709,7 @@ static int do_mprotect_pkey(unsigned long start, size_t len,
const bool rier = (current->personality & READ_IMPLIES_EXEC) &&
(prot & PROT_READ);
struct mmu_gather tlb;
- MA_STATE(mas, &current->mm->mm_mt, 0, 0);
+ struct vma_iterator vmi;

start = untagged_addr(start);

@@ -741,8 +741,8 @@ static int do_mprotect_pkey(unsigned long start, size_t len,
if ((pkey != -1) && !mm_pkey_is_allocated(current->mm, pkey))
goto out;

- mas_set(&mas, start);
- vma = mas_find(&mas, ULONG_MAX);
+ vma_iter_init(&vmi, current->mm, start);
+ vma = vma_find(&vmi, end);
error = -ENOMEM;
if (!vma)
goto out;
@@ -765,18 +765,22 @@ static int do_mprotect_pkey(unsigned long start, size_t len,
}
}

+ prev = vma_prev(&vmi);
if (start > vma->vm_start)
prev = vma;
- else
- prev = mas_prev(&mas, 0);

tlb_gather_mmu(&tlb, current->mm);
- for (nstart = start ; ; ) {
+ nstart = start;
+ tmp = vma->vm_start;
+ for_each_vma_range(vmi, vma, end) {
unsigned long mask_off_old_flags;
unsigned long newflags;
int new_vma_pkey;

- /* Here we know that vma->vm_start <= nstart < vma->vm_end. */
+ if (vma->vm_start != tmp) {
+ error = -ENOMEM;
+ break;
+ }

/* Does the application expect PROT_READ to imply PROT_EXEC */
if (rier && (vma->vm_flags & VM_MAYEXEC))
@@ -819,25 +823,18 @@ static int do_mprotect_pkey(unsigned long start, size_t len,
break;
}

- error = mprotect_fixup(&tlb, vma, &prev, nstart, tmp, newflags);
+ error = mprotect_fixup(&vmi, &tlb, vma, &prev, nstart, tmp, newflags);
if (error)
break;

nstart = tmp;
-
- if (nstart < prev->vm_end)
- nstart = prev->vm_end;
- if (nstart >= end)
- break;
-
- vma = find_vma(current->mm, prev->vm_end);
- if (!vma || vma->vm_start != nstart) {
- error = -ENOMEM;
- break;
- }
prot = reqprot;
}
tlb_finish_mmu(&tlb);
+
+	/* Fail on a trailing gap, but never overwrite an earlier error. */
+	if (!error && tmp < end)
+		error = -ENOMEM;
out:
mmap_write_unlock(current->mm);
return error;
--
2.35.1