KSM WARN_ON_ONCE(page_mapped(page)) in remove_stable_node()

From: Andrey Ryabinin
Date: Wed Nov 13 2019 - 05:34:32 EST


When remove_stable_node() races with __mmput() and squeezes in between ksm_exit() and exit_mmap(),
the WARN_ON_ONCE(page_mapped(page)) in remove_stable_node() could be triggered.

Should we just remove the warning? That seems safe to do, since all callers are able to handle -EBUSY.
Or is there a better way to fix this?



It's easily reproducible with the following script:
(ksm_test.c attached)

#!/bin/bash

# Build the reproducer. -lnuma is listed after the source file: the linker
# resolves libraries in command-line order, so a library named before the
# object that needs migrate_pages() can be dropped (e.g. with --as-needed).
gcc -O2 ksm_test.c -o ksm_test -lnuma
# Writing to /sys/kernel/mm/ksm/run normally requires root.
# Per the kernel KSM documentation, 1 starts ksmd.
echo 1 > /sys/kernel/mm/ksm/run
./ksm_test &
# Give the children a moment to map and advise their pages.
sleep 1
# Per the kernel KSM documentation, 2 stops ksmd and unmerges all merged pages.
echo 2 > /sys/kernel/mm/ksm/run

and the patch below, which provokes that race.

---
include/linux/ksm.h | 4 +++-
include/linux/mm_types.h | 1 +
kernel/fork.c | 4 ++++
3 files changed, 8 insertions(+), 1 deletion(-)

diff --git a/include/linux/ksm.h b/include/linux/ksm.h
index e48b1e453ff5..18384ea472f8 100644
--- a/include/linux/ksm.h
+++ b/include/linux/ksm.h
@@ -33,8 +33,10 @@ static inline int ksm_fork(struct mm_struct *mm, struct mm_struct *oldmm)

static inline void ksm_exit(struct mm_struct *mm)
{
- if (test_bit(MMF_VM_MERGEABLE, &mm->flags))
+ if (test_bit(MMF_VM_MERGEABLE, &mm->flags)) {
__ksm_exit(mm);
+ mm->ksm_wait = 1;
+ }
}

/*
diff --git a/include/linux/mm_types.h b/include/linux/mm_types.h
index 270aa8fd2800..3df8290528c2 100644
--- a/include/linux/mm_types.h
+++ b/include/linux/mm_types.h
@@ -463,6 +463,7 @@ struct mm_struct {

/* Architecture-specific MM context */
mm_context_t context;
+ unsigned long ksm_wait;

unsigned long flags; /* Must use atomic bitops to access */

diff --git a/kernel/fork.c b/kernel/fork.c
index 5fb7e1fa0b05..be6ef4e046f0 100644
--- a/kernel/fork.c
+++ b/kernel/fork.c
@@ -1074,6 +1074,10 @@ static inline void __mmput(struct mm_struct *mm)
uprobe_clear_state(mm);
exit_aio(mm);
ksm_exit(mm);
+
+ if (mm->ksm_wait)
+ schedule_timeout_uninterruptible(10*HZ);
+
khugepaged_exit(mm); /* must run before exit_mmap */
exit_mmap(mm);
mm_put_huge_zero_page(mm);
--
2.23.0
#include <signal.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <unistd.h>

#include <sys/mman.h>
#include <sys/types.h>
#include <sys/wait.h>

#include <numaif.h>


/*
 * Number of node-mask bits the migration branch of merge_and_migrate()
 * chooses between; also passed as the maxnode argument of migrate_pages().
 * The commented-out line is the alternative value for a 4-node setup.
 */
//#define NR_NODES 4
#define NR_NODES 1

/* Size in bytes of the anonymous mapping each child marks MADV_MERGEABLE. */
#define MAP_SIZE 4096

/*
 * Number of children main() keeps alive. Despite the name these are
 * fork()ed processes, not threads.
 */
#define NR_THREADS 1024

/* PID currently occupying each child slot; written and refreshed by main(). */
pid_t pids[NR_THREADS];

/*
 * Body of every forked child.
 *
 * Maps MAP_SIZE bytes of private anonymous memory, fills it with 0xff and
 * marks it MADV_MERGEABLE (presumably so the identical pages of all children
 * become KSM merge candidates), then parks in a very long sleep(); the parent
 * is expected to SIGKILL the child while it is parked.
 *
 * Should the sleep ever return, the child loops forever, randomly either
 * rewriting the mapping with a new fill byte or asking the kernel to migrate
 * its own pages between nodes. The migrate_pages() result is deliberately
 * ignored: this is best-effort stress, not a correctness check.
 *
 * Exits with status 1 if mmap() or madvise() fails. The trailing return is
 * never reached; it only satisfies the int return type.
 */
int merge_and_migrate(void)
{
	void *p;
	unsigned long old_node, new_node;
	int j;

	p = mmap(NULL, MAP_SIZE, PROT_READ | PROT_WRITE,
		 MAP_PRIVATE | MAP_ANONYMOUS, -1, 0);
	if (p == MAP_FAILED) {
		perror("mmap");
		exit(1);
	}

	memset(p, 0xff, MAP_SIZE);
	if (madvise(p, MAP_SIZE, MADV_MERGEABLE)) {
		perror("madvise");
		exit(1);
	}
	sleep(1000000);

	while (1) {
		/* Zero-length sleep between iterations, kept from the original. */
		sleep(0);
		if (rand() % 2 == 0) {
			/* Rewrite the page with a random fill byte in [0, 127]. */
			memset(p, rand() % 128, MAP_SIZE);
		} else {
			/*
			 * Build single-bit node masks. 1UL keeps the shift in the
			 * mask's own type instead of int.
			 */
			j = rand() % NR_NODES;
			old_node = 1UL << j;
			new_node = 1UL << ((j + 1) % NR_NODES);

			/* pid 0 targets the calling process. */
			migrate_pages(0, NR_NODES, &old_node, &new_node);
		}
	}
	return 0;
}

/*
 * Stress driver.
 *
 * Forks NR_THREADS children that each run merge_and_migrate(), recording
 * their PIDs in pids[]. It then loops forever: whenever a child has exited,
 * it is reaped and a replacement is forked into the same slot; otherwise a
 * randomly chosen child is sent SIGKILL. The loop polls with WNOHANG and
 * never blocks, so children are killed and respawned as fast as possible.
 *
 * Returns 1 if fork() or waitpid() fails; it does not return otherwise.
 */
int main(void)
{
	pid_t pid;
	int i;

	for (i = 0; i < NR_THREADS; i++) {
		pid = fork();
		if (pid < 0) {
			perror("fork");
			return 1;
		}
		if (pid == 0) {
			/*
			 * Child: must never fall through into the parent's
			 * fork/kill loops, even if the worker were to return.
			 */
			_exit(merge_and_migrate());
		}
		pids[i] = pid;
	}

	while (1) {
		pid = waitpid(-1, NULL, WNOHANG);
		if (pid < 0) {
			perror("waitpid failed");
			return 1;
		}
		if (pid == 0) {
			/* Nothing to reap right now: kill a random child. */
			i = rand() % NR_THREADS;
			kill(pids[i], SIGKILL);
			continue;
		}

		/* A child exited: respawn a worker in the slot it occupied. */
		for (i = 0; i < NR_THREADS; i++) {
			if (pids[i] != pid)
				continue;

			pid = fork();
			if (pid < 0) {
				perror("fork in while");
				return 1;
			}
			if (pid == 0)
				_exit(merge_and_migrate());
			pids[i] = pid;
			break;
		}
	}
	return 0;
}