[PATCH 4.11 147/150] audit: fix the RCU locking for the auditd_connection structure

From: Greg Kroah-Hartman
Date: Mon Jun 12 2017 - 11:33:03 EST


4.11-stable review patch. If anyone has any objections, please let me know.

------------------

From: Paul Moore <paul@xxxxxxxxxxxxxx>

commit 48d0e023af9799cd7220335baf8e3ba61eeafbeb upstream.

Cong Wang correctly pointed out that the RCU read locking of the
auditd_connection struct was wrong; this patch corrects this by
adopting a more traditional, and correct, RCU locking model.

This patch is heavily based on an earlier prototype by Cong Wang.

Reported-by: Cong Wang <xiyou.wangcong@xxxxxxxxx>
Signed-off-by: Cong Wang <xiyou.wangcong@xxxxxxxxx>
Signed-off-by: Paul Moore <paul@xxxxxxxxxxxxxx>
Signed-off-by: Greg Kroah-Hartman <gregkh@xxxxxxxxxxxxxxxxxxx>


---
kernel/audit.c | 167 +++++++++++++++++++++++++++++++++++++++------------------
1 file changed, 115 insertions(+), 52 deletions(-)

--- a/kernel/audit.c
+++ b/kernel/audit.c
@@ -110,18 +110,19 @@ struct audit_net {
* @pid: auditd PID
* @portid: netlink portid
* @net: the associated network namespace
- * @lock: spinlock to protect write access
+ * @rcu: RCU head
*
* Description:
* This struct is RCU protected; you must either hold the RCU lock for reading
- * or the included spinlock for writing.
+ * or the associated spinlock for writing.
*/
static struct auditd_connection {
int pid;
u32 portid;
struct net *net;
- spinlock_t lock;
-} auditd_conn;
+ struct rcu_head rcu;
+} *auditd_conn = NULL;
+static DEFINE_SPINLOCK(auditd_conn_lock);

/* If audit_rate_limit is non-zero, limit the rate of sending audit records
* to that number per second. This prevents DoS attacks, but results in
@@ -223,15 +224,39 @@ struct audit_reply {
int auditd_test_task(const struct task_struct *task)
{
int rc;
+ struct auditd_connection *ac;

rcu_read_lock();
- rc = (auditd_conn.pid && task->tgid == auditd_conn.pid ? 1 : 0);
+ ac = rcu_dereference(auditd_conn);
+ rc = (ac && ac->pid == task->tgid ? 1 : 0);
rcu_read_unlock();

return rc;
}

/**
+ * auditd_pid_vnr - Return the auditd PID relative to the namespace
+ *
+ * Description:
+ * Returns the PID in relation to the namespace, 0 on failure.
+ */
+static pid_t auditd_pid_vnr(void)
+{
+ pid_t pid;
+ const struct auditd_connection *ac;
+
+ rcu_read_lock();
+ ac = rcu_dereference(auditd_conn);
+ if (!ac)
+ pid = 0;
+ else
+ pid = ac->pid;
+ rcu_read_unlock();
+
+ return pid;
+}
+
+/**
* audit_get_sk - Return the audit socket for the given network namespace
* @net: the destination network namespace
*
@@ -427,6 +452,23 @@ static int audit_set_failure(u32 state)
}

/**
+ * auditd_conn_free - RCU helper to release an auditd connection struct
+ * @rcu: RCU head
+ *
+ * Description:
+ * Drop any references inside the auditd connection tracking struct and free
+ * the memory.
+ */
+static void auditd_conn_free(struct rcu_head *rcu)
+{
+ struct auditd_connection *ac;
+
+ ac = container_of(rcu, struct auditd_connection, rcu);
+ put_net(ac->net);
+ kfree(ac);
+}
+
+/**
* auditd_set - Set/Reset the auditd connection state
* @pid: auditd PID
* @portid: auditd netlink portid
@@ -434,22 +476,33 @@ static int audit_set_failure(u32 state)
*
* Description:
* This function will obtain and drop network namespace references as
- * necessary.
+ * necessary. Returns zero on success, negative values on failure.
*/
-static void auditd_set(int pid, u32 portid, struct net *net)
+static int auditd_set(int pid, u32 portid, struct net *net)
{
unsigned long flags;
+ struct auditd_connection *ac_old, *ac_new;

- spin_lock_irqsave(&auditd_conn.lock, flags);
- auditd_conn.pid = pid;
- auditd_conn.portid = portid;
- if (auditd_conn.net)
- put_net(auditd_conn.net);
- if (net)
- auditd_conn.net = get_net(net);
- else
- auditd_conn.net = NULL;
- spin_unlock_irqrestore(&auditd_conn.lock, flags);
+ if (!pid || !net)
+ return -EINVAL;
+
+ ac_new = kzalloc(sizeof(*ac_new), GFP_KERNEL);
+ if (!ac_new)
+ return -ENOMEM;
+ ac_new->pid = pid;
+ ac_new->portid = portid;
+ ac_new->net = get_net(net);
+
+ spin_lock_irqsave(&auditd_conn_lock, flags);
+ ac_old = rcu_dereference_protected(auditd_conn,
+ lockdep_is_held(&auditd_conn_lock));
+ rcu_assign_pointer(auditd_conn, ac_new);
+ spin_unlock_irqrestore(&auditd_conn_lock, flags);
+
+ if (ac_old)
+ call_rcu(&ac_old->rcu, auditd_conn_free);
+
+ return 0;
}

/**
@@ -544,13 +597,19 @@ static void kauditd_retry_skb(struct sk_
*/
static void auditd_reset(void)
{
+ unsigned long flags;
struct sk_buff *skb;
+ struct auditd_connection *ac_old;

/* if it isn't already broken, break the connection */
- rcu_read_lock();
- if (auditd_conn.pid)
- auditd_set(0, 0, NULL);
- rcu_read_unlock();
+ spin_lock_irqsave(&auditd_conn_lock, flags);
+ ac_old = rcu_dereference_protected(auditd_conn,
+ lockdep_is_held(&auditd_conn_lock));
+ rcu_assign_pointer(auditd_conn, NULL);
+ spin_unlock_irqrestore(&auditd_conn_lock, flags);
+
+ if (ac_old)
+ call_rcu(&ac_old->rcu, auditd_conn_free);

/* flush all of the main and retry queues to the hold queue */
while ((skb = skb_dequeue(&audit_retry_queue)))
@@ -576,6 +635,7 @@ static int auditd_send_unicast_skb(struc
u32 portid;
struct net *net;
struct sock *sk;
+ struct auditd_connection *ac;

/* NOTE: we can't call netlink_unicast while in the RCU section so
* take a reference to the network namespace and grab local
@@ -585,15 +645,15 @@ static int auditd_send_unicast_skb(struc
* section netlink_unicast() should safely return an error */

rcu_read_lock();
- if (!auditd_conn.pid) {
+ ac = rcu_dereference(auditd_conn);
+ if (!ac) {
rcu_read_unlock();
rc = -ECONNREFUSED;
goto err;
}
- net = auditd_conn.net;
- get_net(net);
+ net = get_net(ac->net);
sk = audit_get_sk(net);
- portid = auditd_conn.portid;
+ portid = ac->portid;
rcu_read_unlock();

rc = netlink_unicast(sk, skb, portid, 0);
@@ -728,6 +788,7 @@ static int kauditd_thread(void *dummy)
u32 portid = 0;
struct net *net = NULL;
struct sock *sk = NULL;
+ struct auditd_connection *ac;

#define UNICAST_RETRIES 5

@@ -735,14 +796,14 @@ static int kauditd_thread(void *dummy)
while (!kthread_should_stop()) {
/* NOTE: see the lock comments in auditd_send_unicast_skb() */
rcu_read_lock();
- if (!auditd_conn.pid) {
+ ac = rcu_dereference(auditd_conn);
+ if (!ac) {
rcu_read_unlock();
goto main_queue;
}
- net = auditd_conn.net;
- get_net(net);
+ net = get_net(ac->net);
sk = audit_get_sk(net);
- portid = auditd_conn.portid;
+ portid = ac->portid;
rcu_read_unlock();

/* attempt to flush the hold queue */
@@ -1102,9 +1163,7 @@ static int audit_receive_msg(struct sk_b
memset(&s, 0, sizeof(s));
s.enabled = audit_enabled;
s.failure = audit_failure;
- rcu_read_lock();
- s.pid = auditd_conn.pid;
- rcu_read_unlock();
+ s.pid = auditd_pid_vnr();
s.rate_limit = audit_rate_limit;
s.backlog_limit = audit_backlog_limit;
s.lost = atomic_read(&audit_lost);
@@ -1143,38 +1202,44 @@ static int audit_receive_msg(struct sk_b
/* test the auditd connection */
audit_replace(requesting_pid);

- rcu_read_lock();
- auditd_pid = auditd_conn.pid;
+ auditd_pid = auditd_pid_vnr();
/* only the current auditd can unregister itself */
if ((!new_pid) && (requesting_pid != auditd_pid)) {
- rcu_read_unlock();
audit_log_config_change("audit_pid", new_pid,
auditd_pid, 0);
return -EACCES;
}
/* replacing a healthy auditd is not allowed */
if (auditd_pid && new_pid) {
- rcu_read_unlock();
audit_log_config_change("audit_pid", new_pid,
auditd_pid, 0);
return -EEXIST;
}
- rcu_read_unlock();
-
- if (audit_enabled != AUDIT_OFF)
- audit_log_config_change("audit_pid", new_pid,
- auditd_pid, 1);

if (new_pid) {
/* register a new auditd connection */
- auditd_set(new_pid,
- NETLINK_CB(skb).portid,
- sock_net(NETLINK_CB(skb).sk));
+ err = auditd_set(new_pid,
+ NETLINK_CB(skb).portid,
+ sock_net(NETLINK_CB(skb).sk));
+ if (audit_enabled != AUDIT_OFF)
+ audit_log_config_change("audit_pid",
+ new_pid,
+ auditd_pid,
+ err ? 0 : 1);
+ if (err)
+ return err;
+
/* try to process any backlog */
wake_up_interruptible(&kauditd_wait);
- } else
+ } else {
+ if (audit_enabled != AUDIT_OFF)
+ audit_log_config_change("audit_pid",
+ new_pid,
+ auditd_pid, 1);
+
/* unregister the auditd connection */
auditd_reset();
+ }
}
if (s.mask & AUDIT_STATUS_RATE_LIMIT) {
err = audit_set_rate_limit(s.rate_limit);
@@ -1447,10 +1512,11 @@ static void __net_exit audit_net_exit(st
{
struct audit_net *aunet = net_generic(net, audit_net_id);

- rcu_read_lock();
- if (net == auditd_conn.net)
- auditd_reset();
- rcu_read_unlock();
+ /* NOTE: you would think that we would want to check the auditd
+ * connection and potentially reset it here if it lives in this
+ * namespace, but since the auditd connection tracking struct holds a
+ * reference to this namespace (see auditd_set()) we are only ever
+ * going to get here after that connection has been released */

netlink_kernel_release(aunet->sk);
}
@@ -1470,9 +1536,6 @@ static int __init audit_init(void)
if (audit_initialized == AUDIT_DISABLED)
return 0;

- memset(&auditd_conn, 0, sizeof(auditd_conn));
- spin_lock_init(&auditd_conn.lock);
-
skb_queue_head_init(&audit_queue);
skb_queue_head_init(&audit_retry_queue);
skb_queue_head_init(&audit_hold_queue);