[PATCH] virtio console: Keep a local copy of the control structure

From: Alexander Shishkin
Date: Thu Jan 19 2023 - 14:59:02 EST


When handling control messages, take a single snapshot of the control
structure and work from that copy. Do not repeatedly read its fields from
the buffer shared with the device, whose contents can change under us
between reads. This closes a race between port id validation and control
event decoding that can lead, for example, to a NULL pointer dereference
when removing a nonexistent port.

The control structure is small (8 bytes), so it can simply be copied into
a local variable on the stack.

Signed-off-by: Alexander Shishkin <alexander.shishkin@xxxxxxxxxxxxxxx>
Cc: Greg Kroah-Hartman <gregkh@xxxxxxxxxxxxxxxxxxx>
Cc: Arnd Bergmann <arnd@xxxxxxxx>
Cc: Amit Shah <amit@xxxxxxxxxx>
---
drivers/char/virtio_console.c | 29 +++++++++++++++--------------
1 file changed, 15 insertions(+), 14 deletions(-)

diff --git a/drivers/char/virtio_console.c b/drivers/char/virtio_console.c
index 6a821118d553..42be0991a72f 100644
--- a/drivers/char/virtio_console.c
+++ b/drivers/char/virtio_console.c
@@ -1559,23 +1559,24 @@ static void handle_control_message(struct virtio_device *vdev,
struct ports_device *portdev,
struct port_buffer *buf)
{
- struct virtio_console_control *cpkt;
+ struct virtio_console_control cpkt;
struct port *port;
size_t name_size;
int err;

- cpkt = (struct virtio_console_control *)(buf->buf + buf->offset);
+ /* Keep a local copy of the control structure */
+ memcpy(&cpkt, buf->buf + buf->offset, sizeof(cpkt));

- port = find_port_by_id(portdev, virtio32_to_cpu(vdev, cpkt->id));
+ port = find_port_by_id(portdev, virtio32_to_cpu(vdev, cpkt.id));
if (!port &&
- cpkt->event != cpu_to_virtio16(vdev, VIRTIO_CONSOLE_PORT_ADD)) {
+ cpkt.event != cpu_to_virtio16(vdev, VIRTIO_CONSOLE_PORT_ADD)) {
/* No valid header at start of buffer. Drop it. */
dev_dbg(&portdev->vdev->dev,
- "Invalid index %u in control packet\n", cpkt->id);
+ "Invalid index %u in control packet\n", cpkt.id);
return;
}

- switch (virtio16_to_cpu(vdev, cpkt->event)) {
+ switch (virtio16_to_cpu(vdev, cpkt.event)) {
case VIRTIO_CONSOLE_PORT_ADD:
if (port) {
dev_dbg(&portdev->vdev->dev,
@@ -1583,21 +1584,21 @@ static void handle_control_message(struct virtio_device *vdev,
send_control_msg(port, VIRTIO_CONSOLE_PORT_READY, 1);
break;
}
- if (virtio32_to_cpu(vdev, cpkt->id) >=
+ if (virtio32_to_cpu(vdev, cpkt.id) >=
portdev->max_nr_ports) {
dev_warn(&portdev->vdev->dev,
"Request for adding port with "
"out-of-bound id %u, max. supported id: %u\n",
- cpkt->id, portdev->max_nr_ports - 1);
+ cpkt.id, portdev->max_nr_ports - 1);
break;
}
- add_port(portdev, virtio32_to_cpu(vdev, cpkt->id));
+ add_port(portdev, virtio32_to_cpu(vdev, cpkt.id));
break;
case VIRTIO_CONSOLE_PORT_REMOVE:
unplug_port(port);
break;
case VIRTIO_CONSOLE_CONSOLE_PORT:
- if (!cpkt->value)
+ if (!cpkt.value)
break;
if (is_console_port(port))
break;
@@ -1618,7 +1619,7 @@ static void handle_control_message(struct virtio_device *vdev,
if (!is_console_port(port))
break;

- memcpy(&size, buf->buf + buf->offset + sizeof(*cpkt),
+ memcpy(&size, buf->buf + buf->offset + sizeof(cpkt),
sizeof(size));
set_console_size(port, size.rows, size.cols);

@@ -1627,7 +1628,7 @@ static void handle_control_message(struct virtio_device *vdev,
break;
}
case VIRTIO_CONSOLE_PORT_OPEN:
- port->host_connected = virtio16_to_cpu(vdev, cpkt->value);
+ port->host_connected = virtio16_to_cpu(vdev, cpkt.value);
wake_up_interruptible(&port->waitqueue);
/*
* If the host port got closed and the host had any
@@ -1658,7 +1659,7 @@ static void handle_control_message(struct virtio_device *vdev,
* Skip the size of the header and the cpkt to get the size
* of the name that was sent
*/
- name_size = buf->len - buf->offset - sizeof(*cpkt) + 1;
+ name_size = buf->len - buf->offset - sizeof(cpkt) + 1;

port->name = kmalloc(name_size, GFP_KERNEL);
if (!port->name) {
@@ -1666,7 +1667,7 @@ static void handle_control_message(struct virtio_device *vdev,
"Not enough space to store port name\n");
break;
}
- strncpy(port->name, buf->buf + buf->offset + sizeof(*cpkt),
+ strncpy(port->name, buf->buf + buf->offset + sizeof(cpkt),
name_size - 1);
port->name[name_size - 1] = 0;

--
2.39.0