[PATCH 1/2] mm: gup: add get_user_pages_locked and get_user_pages_unlocked

From: Andrea Arcangeli
Date: Fri Sep 26 2014 - 12:36:53 EST


We can leverage the VM_FAULT_RETRY functionality in the page fault
paths better by using either get_user_pages_locked or
get_user_pages_unlocked.

The former allows converting get_user_pages invocations that need to
know whether mmap_sem was dropped during the call; such callers pass a
"&locked" parameter. Example, from:

down_read(&mm->mmap_sem);
do_something()
get_user_pages(tsk, mm, ..., pages, NULL);
up_read(&mm->mmap_sem);

to:

int locked = 1;
down_read(&mm->mmap_sem);
do_something()
get_user_pages_locked(tsk, mm, ..., pages, &locked);
if (locked)
up_read(&mm->mmap_sem);

The latter is suitable only as a drop-in replacement of the form:

down_read(&mm->mmap_sem);
get_user_pages(tsk, mm, ..., pages, NULL);
up_read(&mm->mmap_sem);

into:

get_user_pages_unlocked(tsk, mm, ..., pages);

Where tsk, mm, the intermediate "..." parameters and "pages" can be any
value as before. Only the last parameter of get_user_pages (vmas) must
be NULL for get_user_pages_locked|unlocked to be usable (for the latter,
the original form wouldn't have been safe anyway if vmas wasn't NULL;
for the former, we just make it explicit by dropping the parameter).

If vmas is not NULL these two methods cannot be used.

This patch then applies the new forms in various places, in some cases
also replacing them with get_user_pages_fast whenever tsk and mm are
current and current->mm. get_user_pages_unlocked differs from
get_user_pages_fast only if mm is not current->mm (like when
get_user_pages works on some other process's mm). Whenever tsk and mm
match current and current->mm, and "force" is not needed
(get_user_pages_fast takes no force argument), get_user_pages_fast
should be used to increase performance and get the pages locklessly
(with only irqs disabled).

Signed-off-by: Andrea Arcangeli <aarcange@xxxxxxxxxx>
---
arch/mips/mm/gup.c | 8 +-
arch/powerpc/mm/gup.c | 6 +-
arch/s390/kvm/kvm-s390.c | 4 +-
arch/s390/mm/gup.c | 6 +-
arch/sh/mm/gup.c | 6 +-
arch/sparc/mm/gup.c | 6 +-
arch/x86/mm/gup.c | 7 +-
drivers/dma/iovlock.c | 10 +--
drivers/iommu/amd_iommu_v2.c | 6 +-
drivers/media/pci/ivtv/ivtv-udma.c | 6 +-
drivers/misc/sgi-gru/grufault.c | 3 +-
drivers/scsi/st.c | 10 +--
drivers/video/fbdev/pvr2fb.c | 5 +-
include/linux/mm.h | 7 ++
mm/gup.c | 147 ++++++++++++++++++++++++++++++++++---
mm/mempolicy.c | 2 +-
mm/nommu.c | 23 ++++++
mm/process_vm_access.c | 7 +-
mm/util.c | 10 +--
net/ceph/pagevec.c | 9 +--
20 files changed, 200 insertions(+), 88 deletions(-)

diff --git a/arch/mips/mm/gup.c b/arch/mips/mm/gup.c
index 06ce17c..20884f5 100644
--- a/arch/mips/mm/gup.c
+++ b/arch/mips/mm/gup.c
@@ -301,11 +301,9 @@ slow_irqon:
start += nr << PAGE_SHIFT;
pages += nr;

- down_read(&mm->mmap_sem);
- ret = get_user_pages(current, mm, start,
- (end - start) >> PAGE_SHIFT,
- write, 0, pages, NULL);
- up_read(&mm->mmap_sem);
+ ret = get_user_pages_unlocked(current, mm, start,
+ (end - start) >> PAGE_SHIFT,
+ write, 0, pages);

/* Have to be a bit careful with return values */
if (nr > 0) {
diff --git a/arch/powerpc/mm/gup.c b/arch/powerpc/mm/gup.c
index d874668..b70c34a 100644
--- a/arch/powerpc/mm/gup.c
+++ b/arch/powerpc/mm/gup.c
@@ -215,10 +215,8 @@ int get_user_pages_fast(unsigned long start, int nr_pages, int write,
start += nr << PAGE_SHIFT;
pages += nr;

- down_read(&mm->mmap_sem);
- ret = get_user_pages(current, mm, start,
- nr_pages - nr, write, 0, pages, NULL);
- up_read(&mm->mmap_sem);
+ ret = get_user_pages_unlocked(current, mm, start,
+ nr_pages - nr, write, 0, pages);

/* Have to be a bit careful with return values */
if (nr > 0) {
diff --git a/arch/s390/kvm/kvm-s390.c b/arch/s390/kvm/kvm-s390.c
index 81b0e11..37ca29a 100644
--- a/arch/s390/kvm/kvm-s390.c
+++ b/arch/s390/kvm/kvm-s390.c
@@ -1092,9 +1092,7 @@ long kvm_arch_fault_in_page(struct kvm_vcpu *vcpu, gpa_t gpa, int writable)
hva = gmap_fault(gpa, vcpu->arch.gmap);
if (IS_ERR_VALUE(hva))
return (long)hva;
- down_read(&mm->mmap_sem);
- rc = get_user_pages(current, mm, hva, 1, writable, 0, NULL, NULL);
- up_read(&mm->mmap_sem);
+ rc = get_user_pages_unlocked(current, mm, hva, 1, writable, 0, NULL);

return rc < 0 ? rc : 0;
}
diff --git a/arch/s390/mm/gup.c b/arch/s390/mm/gup.c
index 639fce46..5c586c7 100644
--- a/arch/s390/mm/gup.c
+++ b/arch/s390/mm/gup.c
@@ -235,10 +235,8 @@ int get_user_pages_fast(unsigned long start, int nr_pages, int write,
/* Try to get the remaining pages with get_user_pages */
start += nr << PAGE_SHIFT;
pages += nr;
- down_read(&mm->mmap_sem);
- ret = get_user_pages(current, mm, start,
- nr_pages - nr, write, 0, pages, NULL);
- up_read(&mm->mmap_sem);
+ ret = get_user_pages_unlocked(current, mm, start,
+ nr_pages - nr, write, 0, pages);
/* Have to be a bit careful with return values */
if (nr > 0)
ret = (ret < 0) ? nr : ret + nr;
diff --git a/arch/sh/mm/gup.c b/arch/sh/mm/gup.c
index 37458f3..e15f52a 100644
--- a/arch/sh/mm/gup.c
+++ b/arch/sh/mm/gup.c
@@ -257,10 +257,8 @@ slow_irqon:
start += nr << PAGE_SHIFT;
pages += nr;

- down_read(&mm->mmap_sem);
- ret = get_user_pages(current, mm, start,
- (end - start) >> PAGE_SHIFT, write, 0, pages, NULL);
- up_read(&mm->mmap_sem);
+ ret = get_user_pages_unlocked(current, mm, start,
+ (end - start) >> PAGE_SHIFT, write, 0, pages);

/* Have to be a bit careful with return values */
if (nr > 0) {
diff --git a/arch/sparc/mm/gup.c b/arch/sparc/mm/gup.c
index 1aed043..fa7de7d 100644
--- a/arch/sparc/mm/gup.c
+++ b/arch/sparc/mm/gup.c
@@ -219,10 +219,8 @@ slow:
start += nr << PAGE_SHIFT;
pages += nr;

- down_read(&mm->mmap_sem);
- ret = get_user_pages(current, mm, start,
- (end - start) >> PAGE_SHIFT, write, 0, pages, NULL);
- up_read(&mm->mmap_sem);
+ ret = get_user_pages_unlocked(current, mm, start,
+ (end - start) >> PAGE_SHIFT, write, 0, pages);

/* Have to be a bit careful with return values */
if (nr > 0) {
diff --git a/arch/x86/mm/gup.c b/arch/x86/mm/gup.c
index 207d9aef..2ab183b 100644
--- a/arch/x86/mm/gup.c
+++ b/arch/x86/mm/gup.c
@@ -388,10 +388,9 @@ slow_irqon:
start += nr << PAGE_SHIFT;
pages += nr;

- down_read(&mm->mmap_sem);
- ret = get_user_pages(current, mm, start,
- (end - start) >> PAGE_SHIFT, write, 0, pages, NULL);
- up_read(&mm->mmap_sem);
+ ret = get_user_pages_unlocked(current, mm, start,
+ (end - start) >> PAGE_SHIFT,
+ write, 0, pages);

/* Have to be a bit careful with return values */
if (nr > 0) {
diff --git a/drivers/dma/iovlock.c b/drivers/dma/iovlock.c
index bb48a57..12ea7c3 100644
--- a/drivers/dma/iovlock.c
+++ b/drivers/dma/iovlock.c
@@ -95,17 +95,11 @@ struct dma_pinned_list *dma_pin_iovec_pages(struct iovec *iov, size_t len)
pages += page_list->nr_pages;

/* pin pages down */
- down_read(&current->mm->mmap_sem);
- ret = get_user_pages(
- current,
- current->mm,
+ ret = get_user_pages_fast(
(unsigned long) iov[i].iov_base,
page_list->nr_pages,
1, /* write */
- 0, /* force */
- page_list->pages,
- NULL);
- up_read(&current->mm->mmap_sem);
+ page_list->pages);

if (ret != page_list->nr_pages)
goto unpin;
diff --git a/drivers/iommu/amd_iommu_v2.c b/drivers/iommu/amd_iommu_v2.c
index 5f578e8..6963b73 100644
--- a/drivers/iommu/amd_iommu_v2.c
+++ b/drivers/iommu/amd_iommu_v2.c
@@ -519,10 +519,8 @@ static void do_fault(struct work_struct *work)

write = !!(fault->flags & PPR_FAULT_WRITE);

- down_read(&fault->state->mm->mmap_sem);
- npages = get_user_pages(NULL, fault->state->mm,
- fault->address, 1, write, 0, &page, NULL);
- up_read(&fault->state->mm->mmap_sem);
+ npages = get_user_pages_unlocked(NULL, fault->state->mm,
+ fault->address, 1, write, 0, &page);

if (npages == 1) {
put_page(page);
diff --git a/drivers/media/pci/ivtv/ivtv-udma.c b/drivers/media/pci/ivtv/ivtv-udma.c
index 7338cb2..96d866b 100644
--- a/drivers/media/pci/ivtv/ivtv-udma.c
+++ b/drivers/media/pci/ivtv/ivtv-udma.c
@@ -124,10 +124,8 @@ int ivtv_udma_setup(struct ivtv *itv, unsigned long ivtv_dest_addr,
}

/* Get user pages for DMA Xfer */
- down_read(&current->mm->mmap_sem);
- err = get_user_pages(current, current->mm,
- user_dma.uaddr, user_dma.page_count, 0, 1, dma->map, NULL);
- up_read(&current->mm->mmap_sem);
+ err = get_user_pages_unlocked(current, current->mm,
+ user_dma.uaddr, user_dma.page_count, 0, 1, dma->map);

if (user_dma.page_count != err) {
IVTV_DEBUG_WARN("failed to map user pages, returned %d instead of %d\n",
diff --git a/drivers/misc/sgi-gru/grufault.c b/drivers/misc/sgi-gru/grufault.c
index f74fc0c..cd20669 100644
--- a/drivers/misc/sgi-gru/grufault.c
+++ b/drivers/misc/sgi-gru/grufault.c
@@ -198,8 +198,7 @@ static int non_atomic_pte_lookup(struct vm_area_struct *vma,
#else
*pageshift = PAGE_SHIFT;
#endif
- if (get_user_pages
- (current, current->mm, vaddr, 1, write, 0, &page, NULL) <= 0)
+ if (get_user_pages_fast(vaddr, 1, write, &page) <= 0)
return -EFAULT;
*paddr = page_to_phys(page);
put_page(page);
diff --git a/drivers/scsi/st.c b/drivers/scsi/st.c
index aff9689..c89dcfa 100644
--- a/drivers/scsi/st.c
+++ b/drivers/scsi/st.c
@@ -4536,18 +4536,12 @@ static int sgl_map_user_pages(struct st_buffer *STbp,
return -ENOMEM;

/* Try to fault in all of the necessary pages */
- down_read(&current->mm->mmap_sem);
/* rw==READ means read from drive, write into memory area */
- res = get_user_pages(
- current,
- current->mm,
+ res = get_user_pages_fast(
uaddr,
nr_pages,
rw == READ,
- 0, /* don't force */
- pages,
- NULL);
- up_read(&current->mm->mmap_sem);
+ pages);

/* Errors and no page mapped should return here */
if (res < nr_pages)
diff --git a/drivers/video/fbdev/pvr2fb.c b/drivers/video/fbdev/pvr2fb.c
index 167cfff..ff81f65 100644
--- a/drivers/video/fbdev/pvr2fb.c
+++ b/drivers/video/fbdev/pvr2fb.c
@@ -686,10 +686,7 @@ static ssize_t pvr2fb_write(struct fb_info *info, const char *buf,
if (!pages)
return -ENOMEM;

- down_read(&current->mm->mmap_sem);
- ret = get_user_pages(current, current->mm, (unsigned long)buf,
- nr_pages, WRITE, 0, pages, NULL);
- up_read(&current->mm->mmap_sem);
+ ret = get_user_pages_fast((unsigned long)buf, nr_pages, WRITE, pages);

if (ret < nr_pages) {
nr_pages = ret;
diff --git a/include/linux/mm.h b/include/linux/mm.h
index 32ba786..69f692d 100644
--- a/include/linux/mm.h
+++ b/include/linux/mm.h
@@ -1197,6 +1197,13 @@ long get_user_pages(struct task_struct *tsk, struct mm_struct *mm,
unsigned long start, unsigned long nr_pages,
int write, int force, struct page **pages,
struct vm_area_struct **vmas);
+long get_user_pages_locked(struct task_struct *tsk, struct mm_struct *mm,
+ unsigned long start, unsigned long nr_pages,
+ int write, int force, struct page **pages,
+ int *locked);
+long get_user_pages_unlocked(struct task_struct *tsk, struct mm_struct *mm,
+ unsigned long start, unsigned long nr_pages,
+ int write, int force, struct page **pages);
int get_user_pages_fast(unsigned long start, int nr_pages, int write,
struct page **pages);
struct kvec;
diff --git a/mm/gup.c b/mm/gup.c
index 91d044b..19e17ab 100644
--- a/mm/gup.c
+++ b/mm/gup.c
@@ -576,6 +576,134 @@ int fixup_user_fault(struct task_struct *tsk, struct mm_struct *mm,
return 0;
}

+/*
+ * __get_user_pages_locked() - common body of get_user_pages(),
+ * get_user_pages_locked() and get_user_pages_unlocked().
+ *
+ * @locked: NULL, or a pointer to an int the caller set to 1 while holding
+ *	mm->mmap_sem for read.  When non-NULL, __get_user_pages() may drop
+ *	mmap_sem on VM_FAULT_RETRY and report it by clearing *@locked; the
+ *	faulting page is then retried once, with mmap_sem re-taken and
+ *	FOLL_TRIED set, and the walk continues with the remaining pages.
+ * @immediate_unlock: true when the caller releases mmap_sem itself right
+ *	after this returns (get_user_pages_unlocked), so it needs no
+ *	notification.  When false and the lock was dropped at any point,
+ *	mmap_sem is not held on return and *@locked is 0.
+ *
+ * This must be __always_inline: the BUILD_BUG_ON() below tests a
+ * parameter, which only builds when every call is inlined with a
+ * constant @immediate_unlock.
+ *
+ * Returns the number of pages handled or, if none were, the last
+ * __get_user_pages() return value (0 or negative).  When @locked or
+ * @pages is NULL the __get_user_pages() result is returned unchanged.
+ */
+static __always_inline long __get_user_pages_locked(struct task_struct *tsk,
+						struct mm_struct *mm,
+						unsigned long start,
+						unsigned long nr_pages,
+						int write, int force,
+						struct page **pages,
+						struct vm_area_struct **vmas,
+						int *locked,
+						bool immediate_unlock)
+{
+	int flags = FOLL_TOUCH;
+	long ret, pages_done;
+	bool lock_dropped;
+
+	if (locked) {
+		/* if VM_FAULT_RETRY can be returned, vmas become invalid */
+		BUG_ON(vmas);
+		/* check caller initialized locked */
+		BUG_ON(*locked != 1);
+	} else {
+		/*
+		 * Not really important, the value is irrelevant if
+		 * locked is NULL, but BUILD_BUG_ON costs nothing.
+		 */
+		BUILD_BUG_ON(immediate_unlock);
+	}
+
+	if (pages)
+		flags |= FOLL_GET;
+	if (write)
+		flags |= FOLL_WRITE;
+	if (force)
+		flags |= FOLL_FORCE;
+
+	pages_done = 0;
+	lock_dropped = false;
+	for (;;) {
+		ret = __get_user_pages(tsk, mm, start, nr_pages, flags, pages,
+				       vmas, locked);
+		if (!locked)
+			/* VM_FAULT_RETRY couldn't trigger, bypass */
+			return ret;
+
+		/*
+		 * VM_FAULT_RETRY cannot return errors, and ret counts only
+		 * the pages before the one that faulted (we seek to offset
+		 * ret below to retry it), so it must be below nr_pages.
+		 */
+		if (!*locked) {
+			BUG_ON(ret < 0);
+			BUG_ON(ret >= nr_pages);
+		}
+
+		if (!pages)
+			/* If it's a prefault don't insist harder */
+			return ret;
+
+		if (ret > 0) {
+			nr_pages -= ret;
+			pages_done += ret;
+			if (!nr_pages)
+				break;
+		}
+		if (*locked) {
+			/* VM_FAULT_RETRY didn't trigger */
+			if (!pages_done)
+				pages_done = ret;
+			break;
+		}
+		/* VM_FAULT_RETRY triggered, so seek to the faulting offset */
+		pages += ret;
+		start += ret << PAGE_SHIFT;
+
+		/*
+		 * Repeat on the address that fired VM_FAULT_RETRY
+		 * without FAULT_FLAG_ALLOW_RETRY but with
+		 * FAULT_FLAG_TRIED.  Only that single page is retried:
+		 * asking for more would apply FOLL_TRIED to pages that
+		 * never faulted and could return more than one page.
+		 */
+		*locked = 1;
+		lock_dropped = true;
+		down_read(&mm->mmap_sem);
+		ret = __get_user_pages(tsk, mm, start, 1, flags | FOLL_TRIED,
+				       pages, NULL, NULL);
+		if (ret != 1) {
+			BUG_ON(ret > 1);
+			if (!pages_done)
+				pages_done = ret;
+			break;
+		}
+		nr_pages--;
+		pages_done++;
+		if (!nr_pages)
+			break;
+		pages++;
+		start += PAGE_SIZE;
+	}
+	if (!immediate_unlock && lock_dropped && *locked) {
+		/*
+		 * We must let the caller know we temporarily dropped the lock
+		 * and so the critical section protected by it was lost.
+		 */
+		up_read(&mm->mmap_sem);
+		*locked = 0;
+	}
+	return pages_done;
+}
+
+/*
+ * get_user_pages_locked() - get_user_pages() variant that may drop mmap_sem
+ *
+ * @locked: must point to an int set to 1, with mm->mmap_sem held for
+ *	read.  On return it is 0 if mmap_sem has been released (the caller
+ *	must then not up_read() it) and 1 if mmap_sem is still held.
+ *
+ * The other arguments are as for get_user_pages().  There is no @vmas
+ * argument because vma pointers become invalid once mmap_sem is dropped.
+ *
+ * Returns the number of pages pinned or, if none were, the last
+ * __get_user_pages() return value (0 or negative).
+ */
+long get_user_pages_locked(struct task_struct *tsk, struct mm_struct *mm,
+			   unsigned long start, unsigned long nr_pages,
+			   int write, int force, struct page **pages,
+			   int *locked)
+{
+	long nr_pinned;
+
+	/* immediate_unlock == false: a dropped mmap_sem is reported via *locked */
+	nr_pinned = __get_user_pages_locked(tsk, mm, start, nr_pages, write,
+					    force, pages, NULL, locked, false);
+	return nr_pinned;
+}
+EXPORT_SYMBOL(get_user_pages_locked);
+
+/*
+ * get_user_pages_unlocked() - drop-in replacement for
+ * down_read(&mm->mmap_sem); get_user_pages(..., NULL); up_read(...);
+ *
+ * Takes mm->mmap_sem for read itself, so the caller must not already hold
+ * it.  Unlike the open-coded sequence, the fault path is allowed to drop
+ * and re-take mmap_sem.  Arguments are as for get_user_pages() without
+ * @vmas; the return value is as for get_user_pages_locked().
+ */
+long get_user_pages_unlocked(struct task_struct *tsk, struct mm_struct *mm,
+			     unsigned long start, unsigned long nr_pages,
+			     int write, int force, struct page **pages)
+{
+	int locked = 1;
+	long nr_pinned;
+
+	down_read(&mm->mmap_sem);
+	nr_pinned = __get_user_pages_locked(tsk, mm, start, nr_pages, write,
+					    force, pages, NULL, &locked, true);
+	/* the fault path may already have released mmap_sem */
+	if (locked)
+		up_read(&mm->mmap_sem);
+	return nr_pinned;
+}
+EXPORT_SYMBOL(get_user_pages_unlocked);
+
/*
* get_user_pages() - pin user pages in memory
* @tsk: the task_struct to use for page fault accounting, or
@@ -625,22 +753,19 @@ int fixup_user_fault(struct task_struct *tsk, struct mm_struct *mm,
* use the correct cache flushing APIs.
*
* See also get_user_pages_fast, for performance critical applications.
+ *
+ * get_user_pages should be gradually obsoleted in favor of
+ * get_user_pages_locked|unlocked. New code should not use get_user_pages:
+ * it cannot pass FAULT_FLAG_ALLOW_RETRY to handle_mm_fault, which in
+ * turn disables the userfaultfd feature. Once all callers are converted,
+ * the "inline" can be dropped from __get_user_pages_locked.
*/
long get_user_pages(struct task_struct *tsk, struct mm_struct *mm,
unsigned long start, unsigned long nr_pages, int write,
int force, struct page **pages, struct vm_area_struct **vmas)
{
- int flags = FOLL_TOUCH;
-
- if (pages)
- flags |= FOLL_GET;
- if (write)
- flags |= FOLL_WRITE;
- if (force)
- flags |= FOLL_FORCE;
-
- return __get_user_pages(tsk, mm, start, nr_pages, flags, pages, vmas,
- NULL);
+ /* locked == NULL: mmap_sem is never dropped, so @vmas stays valid */
+ return __get_user_pages_locked(tsk, mm, start, nr_pages, write, force,
+ pages, vmas, NULL, false);
}
EXPORT_SYMBOL(get_user_pages);

diff --git a/mm/mempolicy.c b/mm/mempolicy.c
index 8f5330d..6606c10 100644
--- a/mm/mempolicy.c
+++ b/mm/mempolicy.c
@@ -881,7 +881,7 @@ static int lookup_node(struct mm_struct *mm, unsigned long addr)
struct page *p;
int err;

- err = get_user_pages(current, mm, addr & PAGE_MASK, 1, 0, 0, &p, NULL);
+ err = get_user_pages_fast(addr & PAGE_MASK, 1, 0, &p);
if (err >= 0) {
err = page_to_nid(p);
put_page(p);
diff --git a/mm/nommu.c b/mm/nommu.c
index a881d96..8a06341 100644
--- a/mm/nommu.c
+++ b/mm/nommu.c
@@ -213,6 +213,29 @@ long get_user_pages(struct task_struct *tsk, struct mm_struct *mm,
}
EXPORT_SYMBOL(get_user_pages);

+/*
+ * !MMU variant of get_user_pages_locked().  The caller holds mm->mmap_sem
+ * for read, and this variant never writes *@locked, so it stays at the 1
+ * the caller initialized it to and the caller remains responsible for
+ * up_read().  The initialization check mirrors the MMU implementation in
+ * mm/gup.c, so a caller that forgets it is caught on !MMU builds too.
+ */
+long get_user_pages_locked(struct task_struct *tsk, struct mm_struct *mm,
+			   unsigned long start, unsigned long nr_pages,
+			   int write, int force, struct page **pages,
+			   int *locked)
+{
+	BUG_ON(locked && *locked != 1);
+	return get_user_pages(tsk, mm, start, nr_pages, write, force,
+			      pages, NULL);
+}
+EXPORT_SYMBOL(get_user_pages_locked);
+
+/*
+ * !MMU variant of get_user_pages_unlocked(): takes mm->mmap_sem for read
+ * around a plain get_user_pages() call (so the caller must not already
+ * hold it) and returns whatever get_user_pages() returned.
+ */
+long get_user_pages_unlocked(struct task_struct *tsk, struct mm_struct *mm,
+			     unsigned long start, unsigned long nr_pages,
+			     int write, int force, struct page **pages)
+{
+	long nr_pinned;
+
+	down_read(&mm->mmap_sem);
+	nr_pinned = get_user_pages(tsk, mm, start, nr_pages, write, force,
+				   pages, NULL);
+	up_read(&mm->mmap_sem);
+
+	return nr_pinned;
+}
+EXPORT_SYMBOL(get_user_pages_unlocked);
+
/**
* follow_pfn - look up PFN at a user virtual address
* @vma: memory mapping
diff --git a/mm/process_vm_access.c b/mm/process_vm_access.c
index 5077afc..b159769 100644
--- a/mm/process_vm_access.c
+++ b/mm/process_vm_access.c
@@ -99,11 +99,8 @@ static int process_vm_rw_single_vec(unsigned long addr,
size_t bytes;

/* Get the pages we're interested in */
- down_read(&mm->mmap_sem);
- pages = get_user_pages(task, mm, pa, pages,
- vm_write, 0, process_pages, NULL);
- up_read(&mm->mmap_sem);
-
+ pages = get_user_pages_unlocked(task, mm, pa, pages,
+ vm_write, 0, process_pages);
if (pages <= 0)
return -EFAULT;

diff --git a/mm/util.c b/mm/util.c
index 093c973..1b93f2d 100644
--- a/mm/util.c
+++ b/mm/util.c
@@ -247,14 +247,8 @@ int __weak get_user_pages_fast(unsigned long start,
int nr_pages, int write, struct page **pages)
{
struct mm_struct *mm = current->mm;
- int ret;
-
- down_read(&mm->mmap_sem);
- ret = get_user_pages(current, mm, start, nr_pages,
- write, 0, pages, NULL);
- up_read(&mm->mmap_sem);
-
- return ret;
+ return get_user_pages_unlocked(current, mm, start, nr_pages,
+ write, 0, pages);
}
EXPORT_SYMBOL_GPL(get_user_pages_fast);

diff --git a/net/ceph/pagevec.c b/net/ceph/pagevec.c
index 5550130..5504783 100644
--- a/net/ceph/pagevec.c
+++ b/net/ceph/pagevec.c
@@ -23,17 +23,16 @@ struct page **ceph_get_direct_page_vector(const void __user *data,
if (!pages)
return ERR_PTR(-ENOMEM);

- down_read(&current->mm->mmap_sem);
while (got < num_pages) {
- rc = get_user_pages(current, current->mm,
- (unsigned long)data + ((unsigned long)got * PAGE_SIZE),
- num_pages - got, write_page, 0, pages + got, NULL);
+ rc = get_user_pages_fast((unsigned long)data +
+ ((unsigned long)got * PAGE_SIZE),
+ num_pages - got,
+ write_page, pages + got);
if (rc < 0)
break;
BUG_ON(rc == 0);
got += rc;
}
- up_read(&current->mm->mmap_sem);
if (rc < 0)
goto fail;
return pages;


Then to make an example your patch would have become:

===