Re: propagating vmgenid outward and upward

From: Michael S. Tsirkin
Date: Wed Mar 02 2022 - 11:29:20 EST


On Wed, Mar 02, 2022 at 04:36:49PM +0100, Jason A. Donenfeld wrote:
> Hi Michael,
>
> On Wed, Mar 02, 2022 at 10:20:25AM -0500, Michael S. Tsirkin wrote:
> > So writing some code:
> >
> > 1:
> > put plaintext in a buffer
> > put a key in a buffer
> > put the nonce for that encryption in a buffer
> >
> > if vm gen id != stored vm gen id
> > stored vm gen id = vm gen id
> > goto 1
> >
> > I think this is race-free, but I don't see why it matters whether we
> > read gen id atomically or not.
>
> Because that 16-byte read of vmgenid is not atomic. Let's say you read
> the first 8 bytes, and then the VM is forked. In the forked VM, the next
> 8 bytes are the same as last time, but the first 8 bytes, which you
> already read, have changed. In that case, your != becomes a ==, and the
> test fails.
>
> This is one of the fundamental differences between a "unique ID" and a
> "generation counter word".
>
> Anyway, per your request in your last email, I wrote some code for this,
> which may or may not be totally broken, and only works on 64-bit x86,
> which is really the best possible case in terms of performance. And even
> so, it's not great.
>
> Jason
>
> --------8<------------------------
>
> diff --git a/drivers/net/wireguard/noise.c b/drivers/net/wireguard/noise.c
> index 720952b92e78..250b8973007d 100644
> --- a/drivers/net/wireguard/noise.c
> +++ b/drivers/net/wireguard/noise.c
> @@ -106,6 +106,7 @@ static struct noise_keypair *keypair_create(struct wg_peer *peer)
> keypair->entry.type = INDEX_HASHTABLE_KEYPAIR;
> keypair->entry.peer = peer;
> kref_init(&keypair->refcount);
> + keypair->vmgenid = vmgenid_read_atomic();
> return keypair;
> }
>
> diff --git a/drivers/net/wireguard/noise.h b/drivers/net/wireguard/noise.h
> index c527253dba80..0add240a14a0 100644
> --- a/drivers/net/wireguard/noise.h
> +++ b/drivers/net/wireguard/noise.h
> @@ -27,10 +27,13 @@ struct noise_symmetric_key {
> bool is_valid;
> };
>
> +extern __uint128_t vmgenid_read_atomic(void);
> +
> struct noise_keypair {
> struct index_hashtable_entry entry;
> struct noise_symmetric_key sending;
> atomic64_t sending_counter;
> + __uint128_t vmgenid;
> struct noise_symmetric_key receiving;
> struct noise_replay_counter receiving_counter;
> __le32 remote_index;
> diff --git a/drivers/net/wireguard/send.c b/drivers/net/wireguard/send.c
> index 5368f7c35b4b..40d016be59e3 100644
> --- a/drivers/net/wireguard/send.c
> +++ b/drivers/net/wireguard/send.c
> @@ -381,6 +381,9 @@ void wg_packet_send_staged_packets(struct wg_peer *peer)
> goto out_invalid;
> }
>
> + if (keypair->vmgenid != vmgenid_read_atomic())
> + goto out_invalid;
> +
> packets.prev->next = NULL;
> wg_peer_get(keypair->entry.peer);
> PACKET_CB(packets.next)->keypair = keypair;

I don't think we care about an atomic read here. All the data is in the
buffer by this point; if the VM did not fork before that, then we are OK
even if it forks during the read.
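To make the disputed case concrete, here is a user-space simulation (a sketch, not kernel code) of the torn read Jason describes above: a 16-byte ID read as two 8-byte halves, with a fork landing between the two loads and changing only the first half. `struct genid`, `simulate_fork()`, `torn_read()` and `genid_equal()` are hypothetical names used only for illustration.

```c
#include <stdbool.h>
#include <stdint.h>

/* Hypothetical stand-in for the 16-byte vmgenid buffer. */
struct genid {
	uint64_t lo; /* first 8 bytes */
	uint64_t hi; /* last 8 bytes */
};

/* Hypothetical "fork": only the first 8 bytes change. */
static void simulate_fork(struct genid *id)
{
	id->lo ^= 0xffff;
}

/* Non-atomic 16-byte read done as two 8-byte loads; the fork may
 * land between them, as in the quoted example. */
static struct genid torn_read(struct genid *id, bool fork_mid_read)
{
	struct genid r;

	r.lo = id->lo;
	if (fork_mid_read)
		simulate_fork(id);
	r.hi = id->hi;
	return r;
}

static bool genid_equal(struct genid a, struct genid b)
{
	return a.lo == b.lo && a.hi == b.hi;
}
```

With `mem = { 0x1111, 0x2222 }` and `stored = mem`, `torn_read(&mem, true)` returns a value equal to `stored` even though `mem` itself has changed, so the `!=` check misses the fork. The argument above is that, in this pattern, the buffer was already filled before the fork, so the miss does no harm.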

We probably do need a memory barrier to make sure all writes complete
before vmgenid is read. I'm not sure which kind: I think the hypervisor
can be trusted to do a full CPU barrier on fork, so probably just a
compiler barrier is enough.
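A user-space sketch of the retry loop from the quoted pseudocode, with the ordering described above: fill the buffer, issue a compiler-only barrier, then compare the generation ID with a plain read. `barrier()` uses the same `asm volatile("" ::: "memory")` definition the kernel uses with GCC; `vm_genid`, `struct packet_buf`, `read_genid()` and `fill_until_stable()` are hypothetical. Following the assumption that the hypervisor does a full CPU barrier on fork, no hardware fence is emitted.

```c
#include <stddef.h>
#include <stdint.h>
#include <string.h>

/* Compiler-only barrier, as defined in the kernel for GCC/Clang. It
 * stops the compiler from moving memory accesses across it but emits
 * no CPU fence instruction. */
#define barrier() __asm__ __volatile__("" : : : "memory")

struct genid {
	uint64_t lo;
	uint64_t hi;
};

/* Hypothetical stand-in for the memory the hypervisor rewrites on fork. */
static volatile struct genid vm_genid;

/* Hypothetical buffer holding everything step 1 writes. */
struct packet_buf {
	uint8_t plaintext[64];
	uint8_t key[32];
	uint64_t nonce;
};

/* Plain (non-atomic) read of both halves, argued above to be enough here. */
static struct genid read_genid(void)
{
	struct genid g;

	g.lo = vm_genid.lo;
	g.hi = vm_genid.hi;
	return g;
}

/* Fill the buffer, then check the generation ID; if it changed, record
 * the new value and start over. Returns the number of passes made. */
static unsigned int fill_until_stable(struct packet_buf *buf,
				      struct genid *stored,
				      const uint8_t *msg, size_t len,
				      const uint8_t key[32], uint64_t nonce)
{
	unsigned int passes = 0;

	if (len > sizeof(buf->plaintext))
		len = sizeof(buf->plaintext);

	for (;;) {
		struct genid now;

		passes++;
		/* Step 1: put plaintext, key and nonce in the buffer. */
		memcpy(buf->plaintext, msg, len);
		memcpy(buf->key, key, sizeof(buf->key));
		buf->nonce = nonce;

		/* Keep the compiler from moving the buffer writes past the ID read. */
		barrier();

		now = read_genid();
		if (now.lo == stored->lo && now.hi == stored->hi)
			return passes;
		*stored = now; /* ID changed: remember it and redo step 1 */
	}
}
```

Called with a `stored` value that already matches `vm_genid`, it returns after one pass; with a stale value it records the new ID and fills the buffer a second time, mirroring the `goto 1` in the quoted pseudocode.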

> diff --git a/drivers/virt/vmgenid.c b/drivers/virt/vmgenid.c
> index 0ae1a39f2e28..c122fae1d494 100644
> --- a/drivers/virt/vmgenid.c
> +++ b/drivers/virt/vmgenid.c
> @@ -21,6 +21,21 @@ struct vmgenid_state {
> u8 this_id[VMGENID_SIZE];
> };
>
> +static __uint128_t *val;
> +
> +__uint128_t vmgenid_read_atomic(void)
> +{
> + __uint128_t ret = 0;
> + if (!val)
> + return 0;
> + asm volatile("lock cmpxchg16b %1"
> + : "+A"(ret)
> + : "m"(*val), "b"(0), "c"(0)
> + : "cc");
> + return ret;
> +}
> +EXPORT_SYMBOL(vmgenid_read_atomic);
> +
> static int vmgenid_add(struct acpi_device *device)
> {
> struct acpi_buffer parsed = { ACPI_ALLOCATE_BUFFER };
> @@ -50,6 +65,7 @@ static int vmgenid_add(struct acpi_device *device)
> phys_addr = (obj->package.elements[0].integer.value << 0) |
> (obj->package.elements[1].integer.value << 32);
> state->next_id = devm_memremap(&device->dev, phys_addr, VMGENID_SIZE, MEMREMAP_WB);
> + val = IS_ERR(state->next_id) ? NULL : (__uint128_t *)state->next_id;
> if (IS_ERR(state->next_id)) {
> ret = PTR_ERR(state->next_id);
> goto out;