[PATCH v3 1/1] kthread: allocate kthread structure using kmalloc

From: Roman Pen
Date: Tue Oct 25 2016 - 07:06:05 EST


This patch avoids allocating the kthread structure on the stack and simply
uses kmalloc. Stack allocation became a huge problem (with memory
corruption and all the other unpleasant consequences) after commit
2deb4be28077 ("x86/dumpstack: When OOPSing, rewind the stack before do_exit()")
by Andy Lutomirski, which rewinds the stack on oops; the oopsed kthread
then steps on garbage memory while completing the task->vfork_done
structure on the following path:

oops_end()
  rewind_stack_do_exit()
    exit_mm()
      mm_release()
        complete_vfork_done()
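
For illustration, a simplified view of the pre-patch code (see the lines
removed in the diff below): 'self' lives on the kthread's own stack and
current->vfork_done points into it, so a rewound stack leaves vfork_done
pointing at reused memory:

    static int kthread(void *_create)
    {
            struct kthread self;    /* allocated on the kthread's stack */

            init_completion(&self.exited);
            current->vfork_done = &self.exited; /* points into the stack */
            ...
    }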

This patch also merges the two structures 'struct kthread_create_info'
and 'struct kthread' into a single 'struct kthread', whose freeing is
controlled by a reference counter.
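
With the merge, the lifetime rule reduces to plain reference counting; a
simplified view of the ownership (from kthread_create_on_node() in the
diff below):

    /* Two initial references:
     *  - one for the creator, dropped via put_kthread() once the
     *    result has been read;
     *  - one for the new kernel thread, dropped from a task work
     *    callback after complete_vfork_done() has run.
     */
    atomic_set(&kthread->refs, 2);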

The last reference on the kthread is put from a task work callback,
which is invoked from do_exit(). The key point is that the last put
happens *after* complete_vfork_done() is invoked.
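
In do_exit() terms the resulting ordering looks roughly like this
(simplified call chain, details elided):

    do_exit()
      exit_mm()
        mm_release()
          complete_vfork_done()   /* kthread is still referenced here */
      exit_task_work()
        task_work_run()
          put_kthread_cb()        /* final put_kthread() -> kfree() */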

Signed-off-by: Roman Pen <roman.penyaev@xxxxxxxxxxxxxxxx>
Cc: Andy Lutomirski <luto@xxxxxxxxxx>
Cc: Oleg Nesterov <oleg@xxxxxxxxxx>
Cc: Peter Zijlstra <peterz@xxxxxxxxxxxxx>
Cc: Thomas Gleixner <tglx@xxxxxxxxxxxxx>
Cc: Ingo Molnar <mingo@xxxxxxxxxx>
Cc: Tejun Heo <tj@xxxxxxxxxx>
Cc: linux-kernel@xxxxxxxxxxxxxxx
---
v3:
o handle to_live_kthread() callers, which should take a kthread
reference or return NULL. The function was renamed to
to_live_kthread_and_get(); the new caller pattern is sketched below.
o minor comment tweaks.
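
A minimal sketch of the new caller pattern (mirroring kthread_park(),
kthread_unpark() and kthread_stop() in the diff below):

    struct kthread *kthread = to_live_kthread_and_get(k);

    if (kthread) {
            /* ... operate on the still-live kthread ... */
            put_kthread(kthread);
    }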

v2:
o let arch/x86/kernel/dumpstack.c rewind the stack, but do not use the
stack for the structure allocation.

kernel/kthread.c | 198 ++++++++++++++++++++++++++++++++-----------------------
1 file changed, 117 insertions(+), 81 deletions(-)

diff --git a/kernel/kthread.c b/kernel/kthread.c
index 4ab4c3766a80..e8adc10556e0 100644
--- a/kernel/kthread.c
+++ b/kernel/kthread.c
@@ -18,14 +18,19 @@
#include <linux/freezer.h>
#include <linux/ptrace.h>
#include <linux/uaccess.h>
+#include <linux/task_work.h>
#include <trace/events/sched.h>

static DEFINE_SPINLOCK(kthread_create_lock);
static LIST_HEAD(kthread_create_list);
struct task_struct *kthreadd_task;

-struct kthread_create_info
-{
+struct kthread {
+ struct list_head list;
+ unsigned long flags;
+ unsigned int cpu;
+ atomic_t refs;
+
/* Information passed to kthread() from kthreadd. */
int (*threadfn)(void *data);
void *data;
@@ -33,15 +38,9 @@ struct kthread_create_info

/* Result passed back to kthread_create() from kthreadd. */
struct task_struct *result;
- struct completion *done;

- struct list_head list;
-};
-
-struct kthread {
- unsigned long flags;
- unsigned int cpu;
- void *data;
+ struct callback_head put_work;
+ struct completion *started;
struct completion parked;
struct completion exited;
};
@@ -56,17 +55,49 @@ enum KTHREAD_BITS {
#define __to_kthread(vfork) \
container_of(vfork, struct kthread, exited)

+static inline void get_kthread(struct kthread *kthread)
+{
+ BUG_ON(atomic_read(&kthread->refs) <= 0);
+ atomic_inc(&kthread->refs);
+}
+
+static inline void put_kthread(struct kthread *kthread)
+{
+ BUG_ON(atomic_read(&kthread->refs) <= 0);
+ if (atomic_dec_and_test(&kthread->refs))
+ kfree(kthread);
+}
+
+/**
+ * put_kthread_cb - likely performs the final put; invoked
+ * from do_exit() via task work.
+ */
+static void put_kthread_cb(struct callback_head *work)
+{
+ struct kthread *kthread;
+
+ kthread = container_of(work, struct kthread, put_work);
+ put_kthread(kthread);
+}
+
static inline struct kthread *to_kthread(struct task_struct *k)
{
return __to_kthread(k->vfork_done);
}

-static struct kthread *to_live_kthread(struct task_struct *k)
+static struct kthread *to_live_kthread_and_get(struct task_struct *k)
{
- struct completion *vfork = ACCESS_ONCE(k->vfork_done);
- if (likely(vfork) && try_get_task_stack(k))
- return __to_kthread(vfork);
- return NULL;
+ struct kthread *kthread = NULL;
+
+ BUG_ON(!(k->flags & PF_KTHREAD));
+ task_lock(k);
+ if (likely(k->vfork_done)) {
+ kthread = __to_kthread(k->vfork_done);
+ get_kthread(kthread);
+ }
+ task_unlock(k);
+
+ return kthread;
}

/**
@@ -174,41 +205,37 @@ void kthread_parkme(void)
}
EXPORT_SYMBOL_GPL(kthread_parkme);

-static int kthread(void *_create)
+static int kthreadfn(void *_self)
{
- /* Copy data: it's on kthread's stack */
- struct kthread_create_info *create = _create;
- int (*threadfn)(void *data) = create->threadfn;
- void *data = create->data;
- struct completion *done;
- struct kthread self;
- int ret;
-
- self.flags = 0;
- self.data = data;
- init_completion(&self.exited);
- init_completion(&self.parked);
- current->vfork_done = &self.exited;
-
- /* If user was SIGKILLed, I release the structure. */
- done = xchg(&create->done, NULL);
- if (!done) {
- kfree(create);
- do_exit(-EINTR);
+ struct completion *started;
+ struct kthread *self = _self;
+ int ret = -EINTR;
+
+ /* If user was SIGKILLed, put a ref and exit silently. */
+ started = xchg(&self->started, NULL);
+ if (!started) {
+ put_kthread(self);
+ goto exit;
}
+ /*
+ * Delegate last ref put to a task work, which will happen
+ * after 'vfork_done' completion.
+ */
+ init_task_work(&self->put_work, put_kthread_cb);
+ task_work_add(current, &self->put_work, false);
+ current->vfork_done = &self->exited;
+
/* OK, tell user we're spawned, wait for stop or wakeup */
__set_current_state(TASK_UNINTERRUPTIBLE);
- create->result = current;
- complete(done);
+ self->result = current;
+ complete(started);
schedule();

- ret = -EINTR;
-
- if (!test_bit(KTHREAD_SHOULD_STOP, &self.flags)) {
- __kthread_parkme(&self);
- ret = threadfn(data);
+ if (!test_bit(KTHREAD_SHOULD_STOP, &self->flags)) {
+ __kthread_parkme(self);
+ ret = self->threadfn(self->data);
}
- /* we can't just return, we must preserve "self" on stack */
+exit:
do_exit(ret);
}

@@ -222,25 +249,25 @@ int tsk_fork_get_node(struct task_struct *tsk)
return NUMA_NO_NODE;
}

-static void create_kthread(struct kthread_create_info *create)
+static void create_kthread(struct kthread *kthread)
{
+ struct completion *started;
int pid;

#ifdef CONFIG_NUMA
- current->pref_node_fork = create->node;
+ current->pref_node_fork = kthread->node;
#endif
/* We want our own signal handler (we take no signals by default). */
- pid = kernel_thread(kthread, create, CLONE_FS | CLONE_FILES | SIGCHLD);
+ pid = kernel_thread(kthreadfn, kthread,
+ CLONE_FS | CLONE_FILES | SIGCHLD);
if (pid < 0) {
- /* If user was SIGKILLed, I release the structure. */
- struct completion *done = xchg(&create->done, NULL);
-
- if (!done) {
- kfree(create);
- return;
+ started = xchg(&kthread->started, NULL);
+ if (started) {
+ /* The user was not SIGKILLed and wants the result. */
+ kthread->result = ERR_PTR(pid);
+ complete(started);
}
- create->result = ERR_PTR(pid);
- complete(done);
+ put_kthread(kthread);
}
}

@@ -272,20 +299,26 @@ struct task_struct *kthread_create_on_node(int (*threadfn)(void *data),
const char namefmt[],
...)
{
- DECLARE_COMPLETION_ONSTACK(done);
+ DECLARE_COMPLETION_ONSTACK(started);
struct task_struct *task;
- struct kthread_create_info *create = kmalloc(sizeof(*create),
- GFP_KERNEL);
+ struct kthread *kthread;

- if (!create)
+ kthread = kmalloc(sizeof(*kthread), GFP_KERNEL);
+ if (!kthread)
return ERR_PTR(-ENOMEM);
- create->threadfn = threadfn;
- create->data = data;
- create->node = node;
- create->done = &done;
+ /* One ref for us and one ref for a new kernel thread. */
+ atomic_set(&kthread->refs, 2);
+ kthread->flags = 0;
+ kthread->cpu = 0;
+ kthread->threadfn = threadfn;
+ kthread->data = data;
+ kthread->node = node;
+ kthread->started = &started;
+ init_completion(&kthread->exited);
+ init_completion(&kthread->parked);

spin_lock(&kthread_create_lock);
- list_add_tail(&create->list, &kthread_create_list);
+ list_add_tail(&kthread->list, &kthread_create_list);
spin_unlock(&kthread_create_lock);

wake_up_process(kthreadd_task);
@@ -294,21 +327,23 @@ struct task_struct *kthread_create_on_node(int (*threadfn)(void *data),
* the OOM killer while kthreadd is trying to allocate memory for
* new kernel thread.
*/
- if (unlikely(wait_for_completion_killable(&done))) {
+ if (unlikely(wait_for_completion_killable(&started))) {
/*
* If I was SIGKILLed before kthreadd (or new kernel thread)
- * calls complete(), leave the cleanup of this structure to
- * that thread.
+ * calls complete(), put a ref and return an error.
*/
- if (xchg(&create->done, NULL))
+ if (xchg(&kthread->started, NULL)) {
+ put_kthread(kthread);
+
return ERR_PTR(-EINTR);
+ }
/*
* kthreadd (or new kernel thread) will call complete()
* shortly.
*/
- wait_for_completion(&done);
+ wait_for_completion(&started);
}
- task = create->result;
+ task = kthread->result;
if (!IS_ERR(task)) {
static const struct sched_param param = { .sched_priority = 0 };
va_list args;
@@ -323,7 +358,8 @@ struct task_struct *kthread_create_on_node(int (*threadfn)(void *data),
sched_setscheduler_nocheck(task, SCHED_NORMAL, &param);
set_cpus_allowed_ptr(task, cpu_all_mask);
}
- kfree(create);
+ put_kthread(kthread);
+
return task;
}
EXPORT_SYMBOL(kthread_create_on_node);
@@ -423,11 +459,11 @@ static void __kthread_unpark(struct task_struct *k, struct kthread *kthread)
*/
void kthread_unpark(struct task_struct *k)
{
- struct kthread *kthread = to_live_kthread(k);
+ struct kthread *kthread = to_live_kthread_and_get(k);

if (kthread) {
__kthread_unpark(k, kthread);
- put_task_stack(k);
+ put_kthread(kthread);
}
}
EXPORT_SYMBOL_GPL(kthread_unpark);
@@ -446,7 +482,7 @@ EXPORT_SYMBOL_GPL(kthread_unpark);
*/
int kthread_park(struct task_struct *k)
{
- struct kthread *kthread = to_live_kthread(k);
+ struct kthread *kthread = to_live_kthread_and_get(k);
int ret = -ENOSYS;

if (kthread) {
@@ -457,7 +493,7 @@ int kthread_park(struct task_struct *k)
wait_for_completion(&kthread->parked);
}
}
- put_task_stack(k);
+ put_kthread(kthread);
ret = 0;
}
return ret;
@@ -487,13 +523,13 @@ int kthread_stop(struct task_struct *k)
trace_sched_kthread_stop(k);

get_task_struct(k);
- kthread = to_live_kthread(k);
+ kthread = to_live_kthread_and_get(k);
if (kthread) {
set_bit(KTHREAD_SHOULD_STOP, &kthread->flags);
__kthread_unpark(k, kthread);
wake_up_process(k);
wait_for_completion(&kthread->exited);
- put_task_stack(k);
+ put_kthread(kthread);
}
ret = k->exit_code;
put_task_struct(k);
@@ -523,14 +559,14 @@ int kthreadd(void *unused)

spin_lock(&kthread_create_lock);
while (!list_empty(&kthread_create_list)) {
- struct kthread_create_info *create;
+ struct kthread *kthread;

- create = list_entry(kthread_create_list.next,
- struct kthread_create_info, list);
- list_del_init(&create->list);
+ kthread = list_entry(kthread_create_list.next,
+ struct kthread, list);
+ list_del_init(&kthread->list);
spin_unlock(&kthread_create_lock);

- create_kthread(create);
+ create_kthread(kthread);

spin_lock(&kthread_create_lock);
}
--
2.9.3