[PATCH RFC net-next v4 7/8] vsock: Add lockless sendmsg() support

From: Bobby Eshleman
Date: Fri Jun 09 2023 - 20:59:32 EST


Because the dgram sendmsg() path for AF_VSOCK acquires the socket lock,
it does not scale when many senders share a socket.

Prior to this patch the socket lock is used to protect both reads and
writes to the local_addr, remote_addr, transport, and buffer size
variables of a vsock socket. What follows are the new protection schemes
for these fields that ensure a race-free and usually lock-free
multi-sender sendmsg() path for vsock dgrams.

- local_addr
local_addr changes as a result of binding a socket. The write path
for local_addr is bind() and various vsock_auto_bind() call sites.
After a socket has been bound via vsock_auto_bind() or bind(), subsequent
calls to bind()/vsock_auto_bind() do not write to local_addr again. bind()
rejects the user request and vsock_auto_bind() early exits.
Therefore, the local addr cannot change while a parallel thread is
in sendmsg(), and lock-free reads of the local addr in sendmsg() are safe.
Change: only acquire lock for auto-binding as-needed in sendmsg().

- buffer size variables
Not used by dgram, so they do not need protection. No change.

- remote_addr and transport
Because a remote_addr update may result in a changed transport, and we
would like to read these two fields lock-free but coherently in the
vsock send path, this patch packages these two fields into a new
struct vsock_remote_info that is referenced by an RCU-protected pointer.

Writes are synchronized as usual by the socket lock. Reads only take
place in RCU read-side critical sections. When remote_addr or transport
is updated, a new remote info is allocated. Old readers still see the
old coherent remote_addr/transport pair, and new readers see the new
coherent pair. The coherency between remote_addr and transport that was
previously provided by the socket lock alone is now also preserved by
RCU, but with a highly scalable, lock-free read side.

Helpers are introduced for accessing and updating the new pointer.

The new structure contains an rcu_head so that kfree_rcu() can be
used. This removes the need for writers to call synchronize_rcu() before
freeing old structures, which is more efficient and reduces code
churn where remote_addr/transport are already being updated inside RCU
read-side sections.

Only virtio has been tested, but updates were necessary to the VMCI and
hyperv code. Unfortunately the author does not have access to
VMCI/hyperv systems so those changes are untested.

Perf Tests (results from patch v2)
vCPUS: 16
Threads: 16
Payload: 4KB
Test Runs: 5
Type: SOCK_DGRAM

Before: 245.2 MB/s
After: 509.2 MB/s (+107%)

Notably, on the same test system, vsock dgram even outperforms
multi-threaded UDP over virtio-net with vhost and MQ support enabled.

Throughput metrics for single-threaded SOCK_DGRAM and
single/multi-threaded SOCK_STREAM showed no statistically significant
throughput changes (lowest p-value reaching 0.27), with the mean
difference ranging from -5% to +1%.

Signed-off-by: Bobby Eshleman <bobby.eshleman@xxxxxxxxxxxxx>
---
drivers/vhost/vsock.c | 12 +-
include/linux/virtio_vsock.h | 3 +-
include/net/af_vsock.h | 38 ++-
net/vmw_vsock/af_vsock.c | 399 ++++++++++++++++++++++++++------
net/vmw_vsock/diag.c | 10 +-
net/vmw_vsock/hyperv_transport.c | 27 ++-
net/vmw_vsock/virtio_transport_common.c | 34 ++-
net/vmw_vsock/vmci_transport.c | 84 +++++--
net/vmw_vsock/vsock_bpf.c | 10 +-
9 files changed, 492 insertions(+), 125 deletions(-)

diff --git a/drivers/vhost/vsock.c b/drivers/vhost/vsock.c
index 159c1a22c1a8..b027a780d333 100644
--- a/drivers/vhost/vsock.c
+++ b/drivers/vhost/vsock.c
@@ -297,13 +297,17 @@ static int
vhost_transport_cancel_pkt(struct vsock_sock *vsk)
{
struct vhost_vsock *vsock;
+ unsigned int cid;
int cnt = 0;
int ret = -ENODEV;

rcu_read_lock();
+ /* Don't clobber ret: the !vsock path below must return -ENODEV. */
+ if (vsock_remote_addr_cid(vsk, &cid) < 0)
+ goto out;

/* Find the vhost_vsock according to guest context id */
- vsock = vhost_vsock_get(vsk->remote_addr.svm_cid);
+ vsock = vhost_vsock_get(cid);
if (!vsock)
goto out;

@@ -706,6 +710,10 @@ static void vhost_vsock_flush(struct vhost_vsock *vsock)
static void vhost_vsock_reset_orphans(struct sock *sk)
{
struct vsock_sock *vsk = vsock_sk(sk);
+ unsigned int cid;
+
+ if (vsock_remote_addr_cid(vsk, &cid) < 0)
+ return;

/* vmci_transport.c doesn't take sk_lock here either. At least we're
* under vsock_table_lock so the sock cannot disappear while we're
@@ -713,7 +721,7 @@ static void vhost_vsock_reset_orphans(struct sock *sk)
*/

/* If the peer is still valid, no need to reset connection */
- if (vhost_vsock_get(vsk->remote_addr.svm_cid))
+ if (vhost_vsock_get(cid))
return;

/* If the close timeout is pending, let it expire. This avoids races
diff --git a/include/linux/virtio_vsock.h b/include/linux/virtio_vsock.h
index 237ca87a2ecd..97656e83606f 100644
--- a/include/linux/virtio_vsock.h
+++ b/include/linux/virtio_vsock.h
@@ -231,7 +231,8 @@ virtio_transport_stream_enqueue(struct vsock_sock *vsk,
struct msghdr *msg,
size_t len);
int
-virtio_transport_dgram_enqueue(struct vsock_sock *vsk,
+virtio_transport_dgram_enqueue(const struct vsock_transport *transport,
+ struct vsock_sock *vsk,
struct sockaddr_vm *remote_addr,
struct msghdr *msg,
size_t len);
diff --git a/include/net/af_vsock.h b/include/net/af_vsock.h
index c115e655b4f5..928b09fbc64b 100644
--- a/include/net/af_vsock.h
+++ b/include/net/af_vsock.h
@@ -25,12 +25,17 @@ extern spinlock_t vsock_table_lock;
#define vsock_sk(__sk) ((struct vsock_sock *)__sk)
#define sk_vsock(__vsk) (&(__vsk)->sk)

+struct vsock_remote_info {
+ struct sockaddr_vm addr;
+ struct rcu_head rcu;
+ const struct vsock_transport *transport;
+};
+
struct vsock_sock {
/* sk must be the first member. */
struct sock sk;
- const struct vsock_transport *transport;
struct sockaddr_vm local_addr;
- struct sockaddr_vm remote_addr;
+ struct vsock_remote_info __rcu *remote_info;
/* Links for the global tables of bound and connected sockets. */
struct list_head bound_table;
struct list_head connected_table;
@@ -120,8 +125,8 @@ struct vsock_transport {

/* DGRAM. */
int (*dgram_bind)(struct vsock_sock *, struct sockaddr_vm *);
- int (*dgram_enqueue)(struct vsock_sock *, struct sockaddr_vm *,
- struct msghdr *, size_t len);
+ int (*dgram_enqueue)(const struct vsock_transport *, struct vsock_sock *,
+ struct sockaddr_vm *, struct msghdr *, size_t len);
bool (*dgram_allow)(u32 cid, u32 port);
int (*dgram_get_cid)(struct sk_buff *skb, unsigned int *cid);
int (*dgram_get_port)(struct sk_buff *skb, unsigned int *port);
@@ -196,6 +201,16 @@ void vsock_core_unregister(const struct vsock_transport *t);
/* The transport may downcast this to access transport-specific functions */
const struct vsock_transport *vsock_core_get_transport(struct vsock_sock *vsk);

+static inline struct vsock_remote_info *
+vsock_core_get_remote_info(struct vsock_sock *vsk)
+{
+ /* vsk->remote_info may be accessed if the rcu read lock is held OR the
+ * socket lock is held
+ */
+ return rcu_dereference_check(vsk->remote_info,
+ lockdep_sock_is_held(sk_vsock(vsk)));
+}
+
/**** UTILS ****/

/* vsock_table_lock must be held */
@@ -214,7 +229,7 @@ void vsock_release_pending(struct sock *pending);
void vsock_add_pending(struct sock *listener, struct sock *pending);
void vsock_remove_pending(struct sock *listener, struct sock *pending);
void vsock_enqueue_accept(struct sock *listener, struct sock *connected);
-void vsock_insert_connected(struct vsock_sock *vsk);
+int vsock_insert_connected(struct vsock_sock *vsk);
void vsock_remove_bound(struct vsock_sock *vsk);
void vsock_remove_connected(struct vsock_sock *vsk);
struct sock *vsock_find_bound_socket(struct sockaddr_vm *addr);
@@ -223,7 +238,8 @@ struct sock *vsock_find_connected_socket(struct sockaddr_vm *src,
void vsock_remove_sock(struct vsock_sock *vsk);
void vsock_for_each_connected_socket(struct vsock_transport *transport,
void (*fn)(struct sock *sk));
-int vsock_assign_transport(struct vsock_sock *vsk, struct vsock_sock *psk);
+int vsock_assign_transport(struct vsock_sock *vsk, struct vsock_sock *psk,
+ struct sockaddr_vm *remote_addr);
bool vsock_find_cid(unsigned int cid);
struct sock *vsock_find_bound_dgram_socket(struct sockaddr_vm *addr);

@@ -253,4 +269,14 @@ static inline void __init vsock_bpf_build_proto(void)
{}
#endif

+/* RCU-protected remote addr helpers */
+int vsock_remote_addr_cid(struct vsock_sock *vsk, unsigned int *cid);
+int vsock_remote_addr_port(struct vsock_sock *vsk, unsigned int *port);
+int vsock_remote_addr_cid_port(struct vsock_sock *vsk, unsigned int *cid,
+ unsigned int *port);
+int vsock_remote_addr_copy(struct vsock_sock *vsk, struct sockaddr_vm *dest);
+bool vsock_remote_addr_bound(struct vsock_sock *vsk);
+bool vsock_remote_addr_equals(struct vsock_sock *vsk, struct sockaddr_vm *other);
+int vsock_remote_addr_update_cid_port(struct vsock_sock *vsk, u32 cid, u32 port);
+
#endif /* __AF_VSOCK_H__ */
diff --git a/net/vmw_vsock/af_vsock.c b/net/vmw_vsock/af_vsock.c
index b0b18e7f4299..9e620d67889b 100644
--- a/net/vmw_vsock/af_vsock.c
+++ b/net/vmw_vsock/af_vsock.c
@@ -114,7 +114,12 @@
static int __vsock_bind(struct sock *sk, struct sockaddr_vm *addr);
static void vsock_sk_destruct(struct sock *sk);
static int vsock_queue_rcv_skb(struct sock *sk, struct sk_buff *skb);
+static bool vsock_use_local_transport(unsigned int remote_cid);
static bool sock_type_connectible(u16 type);
+static const struct vsock_transport *
+vsock_connectible_lookup_transport(unsigned int cid, __u8 flags);
+static const struct vsock_transport *
+vsock_dgram_lookup_transport(unsigned int cid, __u8 flags);

/* Protocol family. */
struct proto vsock_proto = {
@@ -146,6 +151,123 @@ static const struct vsock_transport *transport_local;
static DEFINE_MUTEX(vsock_register_mutex);

/**** UTILS ****/
+bool vsock_remote_addr_bound(struct vsock_sock *vsk)
+{
+ struct vsock_remote_info *remote_info;
+ bool ret;
+
+ rcu_read_lock();
+ remote_info = vsock_core_get_remote_info(vsk);
+ if (!remote_info) {
+ rcu_read_unlock();
+ return false;
+ }
+
+ ret = vsock_addr_bound(&remote_info->addr);
+ rcu_read_unlock();
+
+ return ret;
+}
+EXPORT_SYMBOL_GPL(vsock_remote_addr_bound);
+
+int vsock_remote_addr_copy(struct vsock_sock *vsk, struct sockaddr_vm *dest)
+{
+ struct vsock_remote_info *remote_info;
+
+ rcu_read_lock();
+ remote_info = vsock_core_get_remote_info(vsk);
+ if (!remote_info) {
+ rcu_read_unlock();
+ return -EINVAL;
+ }
+ memcpy(dest, &remote_info->addr, sizeof(*dest));
+ rcu_read_unlock();
+
+ return 0;
+}
+EXPORT_SYMBOL_GPL(vsock_remote_addr_copy);
+
+int vsock_remote_addr_cid(struct vsock_sock *vsk, unsigned int *cid)
+{
+ return vsock_remote_addr_cid_port(vsk, cid, NULL);
+}
+EXPORT_SYMBOL_GPL(vsock_remote_addr_cid);
+
+int vsock_remote_addr_port(struct vsock_sock *vsk, unsigned int *port)
+{
+ return vsock_remote_addr_cid_port(vsk, NULL, port);
+}
+EXPORT_SYMBOL_GPL(vsock_remote_addr_port);
+
+int vsock_remote_addr_cid_port(struct vsock_sock *vsk, unsigned int *cid,
+ unsigned int *port)
+{
+ struct vsock_remote_info *remote_info;
+
+ rcu_read_lock();
+ remote_info = vsock_core_get_remote_info(vsk);
+ if (!remote_info) {
+ rcu_read_unlock();
+ return -EINVAL;
+ }
+
+ if (cid)
+ *cid = remote_info->addr.svm_cid;
+ if (port)
+ *port = remote_info->addr.svm_port;
+
+ rcu_read_unlock();
+ return 0;
+}
+EXPORT_SYMBOL_GPL(vsock_remote_addr_cid_port);
+
+/* Replace the RCU-protected remote info; unsupplied fields are initialized
+ * (unbound addr, NULL transport). The socket lock must be held by the caller.
+ */
+static int vsock_set_remote_info(struct vsock_sock *vsk,
+ const struct vsock_transport *transport,
+ struct sockaddr_vm *addr)
+{
+ struct vsock_remote_info *old, *new = NULL;
+
+ if (addr || transport) {
+ new = kmalloc(sizeof(*new), GFP_KERNEL);
+ if (!new)
+ return -ENOMEM;
+
+ if (addr)
+ memcpy(&new->addr, addr, sizeof(new->addr));
+ else
+ vsock_addr_init(&new->addr, VMADDR_CID_ANY, VMADDR_PORT_ANY);
+ new->transport = transport;
+ }
+
+ old = rcu_replace_pointer(vsk->remote_info, new,
+ lockdep_sock_is_held(sk_vsock(vsk)));
+ kfree_rcu(old, rcu);
+
+ return 0;
+}
+
+bool vsock_remote_addr_equals(struct vsock_sock *vsk,
+ struct sockaddr_vm *other)
+{
+ struct vsock_remote_info *remote_info;
+ bool equals;
+
+ rcu_read_lock();
+ remote_info = vsock_core_get_remote_info(vsk);
+ if (!remote_info) {
+ rcu_read_unlock();
+ return false;
+ }
+
+ equals = vsock_addr_equals_addr(&remote_info->addr, other);
+ rcu_read_unlock();
+
+ return equals;
+}
+EXPORT_SYMBOL_GPL(vsock_remote_addr_equals);

/* Each bound VSocket is stored in the bind hash table and each connected
* VSocket is stored in the connected hash table.
@@ -283,10 +405,17 @@ static struct sock *__vsock_find_connected_socket(struct sockaddr_vm *src,

list_for_each_entry(vsk, vsock_connected_sockets(src, dst),
connected_table) {
- if (vsock_addr_equals_addr(src, &vsk->remote_addr) &&
+ struct vsock_remote_info *remote_info;
+
+ rcu_read_lock();
+ remote_info = vsock_core_get_remote_info(vsk);
+ if (remote_info &&
+ vsock_addr_equals_addr(src, &remote_info->addr) &&
dst->svm_port == vsk->local_addr.svm_port) {
+ rcu_read_unlock();
return sk_vsock(vsk);
}
+ rcu_read_unlock();
}

return NULL;
@@ -299,14 +428,25 @@ static void vsock_insert_unbound(struct vsock_sock *vsk)
spin_unlock_bh(&vsock_table_lock);
}

-void vsock_insert_connected(struct vsock_sock *vsk)
+int vsock_insert_connected(struct vsock_sock *vsk)
{
- struct list_head *list = vsock_connected_sockets(
- &vsk->remote_addr, &vsk->local_addr);
+ struct vsock_remote_info *remote_info;
+ struct list_head *list;
+
+ rcu_read_lock();
+ remote_info = vsock_core_get_remote_info(vsk);
+ if (!remote_info) {
+ rcu_read_unlock();
+ return -EINVAL;
+ }
+ list = vsock_connected_sockets(&remote_info->addr, &vsk->local_addr);
+ rcu_read_unlock();

spin_lock_bh(&vsock_table_lock);
__vsock_insert_connected(list, vsk);
spin_unlock_bh(&vsock_table_lock);
+
+ return 0;
}
EXPORT_SYMBOL_GPL(vsock_insert_connected);

@@ -388,7 +528,7 @@ void vsock_for_each_connected_socket(struct vsock_transport *transport,
struct vsock_sock *vsk;
list_for_each_entry(vsk, &vsock_connected_table[i],
connected_table) {
- if (vsk->transport != transport)
+ if (vsock_core_get_transport(vsk) != transport)
continue;

fn(sk_vsock(vsk));
@@ -454,12 +594,19 @@ static bool vsock_use_local_transport(unsigned int remote_cid)

static void vsock_deassign_transport(struct vsock_sock *vsk)
{
- if (!vsk->transport)
+ struct vsock_remote_info *remote_info;
+
+ remote_info = rcu_replace_pointer(vsk->remote_info, NULL,
+ lockdep_sock_is_held(sk_vsock(vsk)));
+ if (!remote_info)
return;

- vsk->transport->destruct(vsk);
- module_put(vsk->transport->module);
- vsk->transport = NULL;
+ if (remote_info->transport) {
+ remote_info->transport->destruct(vsk);
+ module_put(remote_info->transport->module);
+ }
+
+ kfree_rcu(remote_info, rcu);
}

static const struct vsock_transport *
@@ -490,26 +637,29 @@ vsock_dgram_lookup_transport(unsigned int cid, __u8 flags)
return transport_dgram;
}

-/* Assign a transport to a socket and call the .init transport callback.
+/* Assign a transport and remote addr to a socket and call the .init transport
+ * callback.
*
- * Note: for connection oriented socket this must be called when vsk->remote_addr
- * is set (e.g. during the connect() or when a connection request on a listener
- * socket is received).
- * The vsk->remote_addr is used to decide which transport to use:
+ * The remote_addr is used to decide which transport to use. Both the addr
+ * and transport are updated simultaneously via RCU-protected pointer:
* - remote CID == VMADDR_CID_LOCAL or g2h->local_cid or VMADDR_CID_HOST if
* g2h is not loaded, will use local transport;
* - remote CID <= VMADDR_CID_HOST or h2g is not loaded or remote flags field
* includes VMADDR_FLAG_TO_HOST flag value, will use guest->host transport;
* - remote CID > VMADDR_CID_HOST will use host->guest transport;
*/
-int vsock_assign_transport(struct vsock_sock *vsk, struct vsock_sock *psk)
+int vsock_assign_transport(struct vsock_sock *vsk, struct vsock_sock *psk,
+ struct sockaddr_vm *remote_addr)
{
const struct vsock_transport *new_transport;
+ struct vsock_remote_info *old_info;
struct sock *sk = sk_vsock(vsk);
- unsigned int remote_cid = vsk->remote_addr.svm_cid;
+ unsigned int remote_cid;
__u8 remote_flags;
int ret;

+ remote_cid = remote_addr->svm_cid;
+
/* If the packet is coming with the source and destination CIDs higher
* than VMADDR_CID_HOST, then a vsock channel where all the packets are
* forwarded to the host should be established. Then the host will
@@ -519,10 +669,10 @@ int vsock_assign_transport(struct vsock_sock *vsk, struct vsock_sock *psk)
* the connect path the flag can be set by the user space application.
*/
if (psk && vsk->local_addr.svm_cid > VMADDR_CID_HOST &&
- vsk->remote_addr.svm_cid > VMADDR_CID_HOST)
- vsk->remote_addr.svm_flags |= VMADDR_FLAG_TO_HOST;
+ remote_cid > VMADDR_CID_HOST)
+ remote_addr->svm_flags |= VMADDR_FLAG_TO_HOST;

- remote_flags = vsk->remote_addr.svm_flags;
+ remote_flags = remote_addr->svm_flags;

switch (sk->sk_type) {
case SOCK_DGRAM:
@@ -538,8 +688,9 @@ int vsock_assign_transport(struct vsock_sock *vsk, struct vsock_sock *psk)
return -ESOCKTNOSUPPORT;
}

- if (vsk->transport) {
- if (vsk->transport == new_transport)
+ old_info = vsock_core_get_remote_info(vsk);
+ if (old_info && old_info->transport) {
+ if (old_info->transport == new_transport)
return 0;

/* transport->release() must be called with sock lock acquired.
@@ -548,7 +699,7 @@ int vsock_assign_transport(struct vsock_sock *vsk, struct vsock_sock *psk)
* function is called on a new socket which is not assigned to
* any transport.
*/
- vsk->transport->release(vsk);
+ old_info->transport->release(vsk);
vsock_deassign_transport(vsk);
}

@@ -566,13 +717,18 @@ int vsock_assign_transport(struct vsock_sock *vsk, struct vsock_sock *psk)
}
}

- ret = new_transport->init(vsk, psk);
+ ret = vsock_set_remote_info(vsk, new_transport, remote_addr);
if (ret) {
module_put(new_transport->module);
return ret;
}

- vsk->transport = new_transport;
+ ret = new_transport->init(vsk, psk);
+ if (ret) {
+ vsock_set_remote_info(vsk, NULL, NULL);
+ module_put(new_transport->module);
+ return ret;
+ }

return 0;
}
@@ -629,12 +785,14 @@ static bool vsock_is_pending(struct sock *sk)

static int vsock_send_shutdown(struct sock *sk, int mode)
{
+ const struct vsock_transport *transport;
struct vsock_sock *vsk = vsock_sk(sk);

- if (!vsk->transport)
+ transport = vsock_core_get_transport(vsk);
+ if (!transport)
return -ENODEV;

- return vsk->transport->shutdown(vsk, mode);
+ return transport->shutdown(vsk, mode);
}

static void vsock_pending_work(struct work_struct *work)
@@ -757,7 +915,10 @@ static int __vsock_bind_connectible(struct vsock_sock *vsk,
static int vsock_bind_dgram(struct vsock_sock *vsk,
struct sockaddr_vm *addr)
{
- if (!vsk->transport || !vsk->transport->dgram_bind) {
+ const struct vsock_transport *transport;
+
+ transport = vsock_core_get_transport(vsk);
+ if (!transport || !transport->dgram_bind) {
int retval;

spin_lock_bh(&vsock_dgram_table_lock);
@@ -768,7 +929,7 @@ static int vsock_bind_dgram(struct vsock_sock *vsk,
return retval;
}

- return vsk->transport->dgram_bind(vsk, addr);
+ return transport->dgram_bind(vsk, addr);
}

static int __vsock_bind(struct sock *sk, struct sockaddr_vm *addr)
@@ -817,6 +978,7 @@ static struct sock *__vsock_create(struct net *net,
unsigned short type,
int kern)
{
+ struct vsock_remote_info *remote_info;
struct sock *sk;
struct vsock_sock *psk;
struct vsock_sock *vsk;
@@ -836,7 +998,14 @@ static struct sock *__vsock_create(struct net *net,

vsk = vsock_sk(sk);
vsock_addr_init(&vsk->local_addr, VMADDR_CID_ANY, VMADDR_PORT_ANY);
- vsock_addr_init(&vsk->remote_addr, VMADDR_CID_ANY, VMADDR_PORT_ANY);
+ /* Zeroed so that remote_info->transport starts out NULL. */
+ remote_info = kzalloc(sizeof(*remote_info), GFP_KERNEL);
+ if (!remote_info) {
+ sk_free(sk);
+ return NULL;
+ }
+ vsock_addr_init(&remote_info->addr, VMADDR_CID_ANY, VMADDR_PORT_ANY);
+ rcu_assign_pointer(vsk->remote_info, remote_info);

sk->sk_destruct = vsock_sk_destruct;
sk->sk_backlog_rcv = vsock_queue_rcv_skb;
@@ -883,6 +1052,7 @@ static bool sock_type_connectible(u16 type)
static void __vsock_release(struct sock *sk, int level)
{
if (sk) {
+ const struct vsock_transport *transport;
struct sock *pending;
struct vsock_sock *vsk;

@@ -896,8 +1066,9 @@ static void __vsock_release(struct sock *sk, int level)
*/
lock_sock_nested(sk, level);

- if (vsk->transport)
- vsk->transport->release(vsk);
+ transport = vsock_core_get_transport(vsk);
+ if (transport)
+ transport->release(vsk);
else if (sock_type_connectible(sk->sk_type))
vsock_remove_sock(vsk);

@@ -927,8 +1098,6 @@ static void vsock_sk_destruct(struct sock *sk)
* possibly register the address family with the kernel.
*/
vsock_addr_init(&vsk->local_addr, VMADDR_CID_ANY, VMADDR_PORT_ANY);
- vsock_addr_init(&vsk->remote_addr, VMADDR_CID_ANY, VMADDR_PORT_ANY);
-
put_cred(vsk->owner);
}

@@ -952,16 +1121,22 @@ EXPORT_SYMBOL_GPL(vsock_create_connected);

s64 vsock_stream_has_data(struct vsock_sock *vsk)
{
- return vsk->transport->stream_has_data(vsk);
+ const struct vsock_transport *transport;
+
+ transport = vsock_core_get_transport(vsk);
+
+ return transport->stream_has_data(vsk);
}
EXPORT_SYMBOL_GPL(vsock_stream_has_data);

s64 vsock_connectible_has_data(struct vsock_sock *vsk)
{
+ const struct vsock_transport *transport;
struct sock *sk = sk_vsock(vsk);

+ transport = vsock_core_get_transport(vsk);
if (sk->sk_type == SOCK_SEQPACKET)
- return vsk->transport->seqpacket_has_data(vsk);
+ return transport->seqpacket_has_data(vsk);
else
return vsock_stream_has_data(vsk);
}
@@ -969,7 +1144,10 @@ EXPORT_SYMBOL_GPL(vsock_connectible_has_data);

s64 vsock_stream_has_space(struct vsock_sock *vsk)
{
- return vsk->transport->stream_has_space(vsk);
+ const struct vsock_transport *transport;
+
+ transport = vsock_core_get_transport(vsk);
+ return transport->stream_has_space(vsk);
}
EXPORT_SYMBOL_GPL(vsock_stream_has_space);

@@ -1018,6 +1196,7 @@ static int vsock_getname(struct socket *sock,
struct sock *sk;
struct vsock_sock *vsk;
struct sockaddr_vm *vm_addr;
+ struct vsock_remote_info *rcu_ptr;

sk = sock->sk;
vsk = vsock_sk(sk);
@@ -1030,7 +1209,14 @@ static int vsock_getname(struct socket *sock,
err = -ENOTCONN;
goto out;
}
- vm_addr = &vsk->remote_addr;
+
+ rcu_ptr = vsock_core_get_remote_info(vsk);
+ if (!rcu_ptr) {
+ err = -EINVAL;
+ goto out;
+ }
+
+ vm_addr = &rcu_ptr->addr;
} else {
vm_addr = &vsk->local_addr;
}
@@ -1154,7 +1340,7 @@ static __poll_t vsock_poll(struct file *file, struct socket *sock,

lock_sock(sk);

- transport = vsk->transport;
+ transport = vsock_core_get_transport(vsk);

/* Listening sockets that have connections in their accept
* queue can be read.
@@ -1225,9 +1411,11 @@ static __poll_t vsock_poll(struct file *file, struct socket *sock,

static int vsock_read_skb(struct sock *sk, skb_read_actor_t read_actor)
{
+ const struct vsock_transport *transport;
struct vsock_sock *vsk = vsock_sk(sk);

- return vsk->transport->read_skb(vsk, read_actor);
+ transport = vsock_core_get_transport(vsk);
+ return transport->read_skb(vsk, read_actor);
}

static int vsock_dgram_sendmsg(struct socket *sock, struct msghdr *msg,
@@ -1236,7 +1424,7 @@ static int vsock_dgram_sendmsg(struct socket *sock, struct msghdr *msg,
int err;
struct sock *sk;
struct vsock_sock *vsk;
- struct sockaddr_vm *remote_addr;
+ struct sockaddr_vm stack_addr, *remote_addr;
const struct vsock_transport *transport;

if (msg->msg_flags & MSG_OOB)
@@ -1247,7 +1435,23 @@ static int vsock_dgram_sendmsg(struct socket *sock, struct msghdr *msg,
sk = sock->sk;
vsk = vsock_sk(sk);

- lock_sock(sk);
+ /* Auto-binding writes local_addr, so take the socket lock to
+ * serialize against bind() and other auto-binders. An already-bound
+ * socket skips the lock entirely.
+ *
+ * The first check of local_addr is intentionally racy (annotated with
+ * data_race()). vsock_auto_bind() re-checks local_addr under the lock,
+ * so a stale "unbound" read only costs an extra lock/unlock.
+ *
+ * This technique is borrowed from inet_send_prepare().
+ */
+ if (data_race(!vsock_addr_bound(&vsk->local_addr))) {
+ lock_sock(sk);
+ err = vsock_auto_bind(vsk);
+ release_sock(sk);
+ if (err)
+ return err;
+ }

/* If the provided message contains an address, use that. Otherwise
* fall back on the socket's remote handle (if it has been connected).
@@ -1257,6 +1461,7 @@ static int vsock_dgram_sendmsg(struct socket *sock, struct msghdr *msg,
&remote_addr) == 0) {
transport = vsock_dgram_lookup_transport(remote_addr->svm_cid,
remote_addr->svm_flags);
+
if (!transport) {
err = -EINVAL;
goto out;
@@ -1287,18 +1492,50 @@ static int vsock_dgram_sendmsg(struct socket *sock, struct msghdr *msg,
goto out;
}

- err = transport->dgram_enqueue(vsk, remote_addr, msg, len);
+ err = transport->dgram_enqueue(transport, vsk, remote_addr, msg, len);
module_put(transport->module);
} else if (sock->state == SS_CONNECTED) {
- remote_addr = &vsk->remote_addr;
- transport = vsk->transport;
+ struct vsock_remote_info *remote_info;

- err = vsock_auto_bind(vsk);
- if (err)
+ /* Snapshot a coherent addr/transport pair under RCU.
+ * NOTE(review): unlike the msg_name path above, no module
+ * reference is held on this transport while it is used
+ * locklessly; confirm a concurrent connect() cannot release
+ * it (or unload its module) before dgram_enqueue() returns.
+ */
+ rcu_read_lock();
+ remote_info = vsock_core_get_remote_info(vsk);
+ if (!remote_info || !remote_info->transport) {
+ err = -EINVAL;
+ rcu_read_unlock();
goto out;
+ }

- if (remote_addr->svm_cid == VMADDR_CID_ANY)
+ transport = remote_info->transport;
+ memcpy(&stack_addr, &remote_info->addr, sizeof(stack_addr));
+ rcu_read_unlock();
+
+ remote_addr = &stack_addr;
+
+ if (remote_addr->svm_cid == VMADDR_CID_ANY) {
remote_addr->svm_cid = transport->get_local_cid();
+ lock_sock(sk);
+ /* The local CID never resolves to a different
+ * transport, so no new lookup is needed. Persist the
+ * resolved CID only if a concurrent connect() has not
+ * replaced the remote info since the snapshot above.
+ */
+ remote_info = vsock_core_get_remote_info(vsk);
+ err = 0;
+ if (remote_info && remote_info->transport == transport &&
+ remote_info->addr.svm_cid == VMADDR_CID_ANY &&
+ remote_info->addr.svm_port == remote_addr->svm_port)
+ err = vsock_set_remote_info(vsk, transport, remote_addr);
+ release_sock(sk);
+
+ if (err)
+ goto out;
+ }

/* XXX Should connect() or this function ensure remote_addr is
* bound?
@@ -1314,14 +1540,13 @@ static int vsock_dgram_sendmsg(struct socket *sock, struct msghdr *msg,
goto out;
}

- err = transport->dgram_enqueue(vsk, remote_addr, msg, len);
+ err = transport->dgram_enqueue(transport, vsk, &stack_addr, msg, len);
} else {
err = -EINVAL;
goto out;
}

out:
- release_sock(sk);
return err;
}

@@ -1332,18 +1557,22 @@ static int vsock_dgram_connect(struct socket *sock,
struct sock *sk;
struct vsock_sock *vsk;
struct sockaddr_vm *remote_addr;
+ const struct vsock_transport *transport;

sk = sock->sk;
vsk = vsock_sk(sk);

err = vsock_addr_cast(addr, addr_len, &remote_addr);
if (err == -EAFNOSUPPORT && remote_addr->svm_family == AF_UNSPEC) {
+ struct sockaddr_vm addr_any;
+
lock_sock(sk);
- vsock_addr_init(&vsk->remote_addr, VMADDR_CID_ANY,
- VMADDR_PORT_ANY);
+ vsock_addr_init(&addr_any, VMADDR_CID_ANY, VMADDR_PORT_ANY);
+ err = vsock_set_remote_info(vsk, vsock_core_get_transport(vsk),
+ &addr_any);
sock->state = SS_UNCONNECTED;
release_sock(sk);
- return 0;
+ return err;
} else if (err != 0)
return -EINVAL;

@@ -1353,14 +1582,13 @@ static int vsock_dgram_connect(struct socket *sock,
if (err)
goto out;

- memcpy(&vsk->remote_addr, remote_addr, sizeof(vsk->remote_addr));
-
- err = vsock_assign_transport(vsk, NULL);
+ err = vsock_assign_transport(vsk, NULL, remote_addr);
if (err)
goto out;

- if (!vsk->transport->dgram_allow(remote_addr->svm_cid,
- remote_addr->svm_port)) {
+ transport = vsock_core_get_transport(vsk);
+ if (!transport->dgram_allow(remote_addr->svm_cid,
+ remote_addr->svm_port)) {
err = -EINVAL;
goto out;
}
@@ -1407,7 +1635,8 @@ int vsock_dgram_recvmsg(struct socket *sock, struct msghdr *msg,
if (flags & MSG_OOB || flags & MSG_ERRQUEUE)
return -EOPNOTSUPP;

- transport = vsk->transport;
+ /* vsock_core_get_transport() takes the RCU read lock itself. */
+ transport = vsock_core_get_transport(vsk);

/* Retrieve the head sk_buff from the socket's receive queue. */
err = 0;
@@ -1475,7 +1705,7 @@ static const struct proto_ops vsock_dgram_ops = {

static int vsock_transport_cancel_pkt(struct vsock_sock *vsk)
{
- const struct vsock_transport *transport = vsk->transport;
+ const struct vsock_transport *transport = vsock_core_get_transport(vsk);

if (!transport || !transport->cancel_pkt)
return -EOPNOTSUPP;
@@ -1512,6 +1742,7 @@ static int vsock_connect(struct socket *sock, struct sockaddr *addr,
struct sock *sk;
struct vsock_sock *vsk;
const struct vsock_transport *transport;
+ struct vsock_remote_info *remote_info;
struct sockaddr_vm *remote_addr;
long timeout;
DEFINE_WAIT(wait);
@@ -1549,14 +1780,20 @@ static int vsock_connect(struct socket *sock, struct sockaddr *addr,
}

/* Set the remote address that we are connecting to. */
- memcpy(&vsk->remote_addr, remote_addr,
- sizeof(vsk->remote_addr));
-
- err = vsock_assign_transport(vsk, NULL);
+ err = vsock_assign_transport(vsk, NULL, remote_addr);
if (err)
goto out;

- transport = vsk->transport;
+ rcu_read_lock();
+ remote_info = vsock_core_get_remote_info(vsk);
+ if (!remote_info) {
+ err = -EINVAL;
+ rcu_read_unlock();
+ goto out;
+ }
+
+ transport = remote_info->transport;
+ rcu_read_unlock();

/* The hypervisor and well-known contexts do not have socket
* endpoints.
@@ -1820,7 +2057,7 @@ static int vsock_connectible_setsockopt(struct socket *sock,

lock_sock(sk);

- transport = vsk->transport;
+ transport = vsock_core_get_transport(vsk);

switch (optname) {
case SO_VM_SOCKETS_BUFFER_SIZE:
@@ -1958,7 +2195,7 @@ static int vsock_connectible_sendmsg(struct socket *sock, struct msghdr *msg,

lock_sock(sk);

- transport = vsk->transport;
+ transport = vsock_core_get_transport(vsk);

/* Callers should not provide a destination with connection oriented
* sockets.
@@ -1981,7 +2218,7 @@ static int vsock_connectible_sendmsg(struct socket *sock, struct msghdr *msg,
goto out;
}

- if (!vsock_addr_bound(&vsk->remote_addr)) {
+ if (!vsock_remote_addr_bound(vsk)) {
err = -EDESTADDRREQ;
goto out;
}
@@ -2102,7 +2339,7 @@ static int vsock_connectible_wait_data(struct sock *sk,

vsk = vsock_sk(sk);
err = 0;
- transport = vsk->transport;
+ transport = vsock_core_get_transport(vsk);

while (1) {
prepare_to_wait(sk_sleep(sk), wait, TASK_INTERRUPTIBLE);
@@ -2170,7 +2407,7 @@ static int __vsock_stream_recvmsg(struct sock *sk, struct msghdr *msg,
DEFINE_WAIT(wait);

vsk = vsock_sk(sk);
- transport = vsk->transport;
+ transport = vsock_core_get_transport(vsk);

/* We must not copy less than target bytes into the user's buffer
* before returning successfully, so we wait for the consume queue to
@@ -2246,7 +2483,7 @@ static int __vsock_seqpacket_recvmsg(struct sock *sk, struct msghdr *msg,
DEFINE_WAIT(wait);

vsk = vsock_sk(sk);
- transport = vsk->transport;
+ transport = vsock_core_get_transport(vsk);

timeout = sock_rcvtimeo(sk, flags & MSG_DONTWAIT);

@@ -2303,7 +2540,7 @@ vsock_connectible_recvmsg(struct socket *sock, struct msghdr *msg, size_t len,

lock_sock(sk);

- transport = vsk->transport;
+ transport = vsock_core_get_transport(vsk);

if (!transport || sk->sk_state != TCP_ESTABLISHED) {
/* Recvmsg is supposed to return 0 if a peer performs an
@@ -2370,7 +2607,7 @@ static int vsock_set_rcvlowat(struct sock *sk, int val)
if (val > vsk->buffer_size)
return -EINVAL;

- transport = vsk->transport;
+ transport = vsock_core_get_transport(vsk);

if (transport && transport->set_rcvlowat)
return transport->set_rcvlowat(vsk, val);
@@ -2460,7 +2697,10 @@ static int vsock_create(struct net *net, struct socket *sock,
vsk = vsock_sk(sk);

if (sock->type == SOCK_DGRAM) {
- ret = vsock_assign_transport(vsk, NULL);
+ struct sockaddr_vm remote_addr;
+
+ vsock_addr_init(&remote_addr, VMADDR_CID_ANY, VMADDR_PORT_ANY);
+ ret = vsock_assign_transport(vsk, NULL, &remote_addr);
if (ret < 0) {
sock_put(sk);
return ret;
@@ -2582,7 +2822,20 @@ static void __exit vsock_exit(void)

const struct vsock_transport *vsock_core_get_transport(struct vsock_sock *vsk)
{
- return vsk->transport;
+ const struct vsock_transport *transport = NULL;
+ struct vsock_remote_info *remote_info;
+
+ /* Only the remote info is RCU-managed; the transport pointer is
+ * returned after rcu_read_unlock(). NOTE: this assumes transports are
+ * long-lived objects that are not freed together with the remote info.
+ */
+ rcu_read_lock();
+ remote_info = vsock_core_get_remote_info(vsk);
+ if (remote_info)
+ transport = remote_info->transport;
+ rcu_read_unlock();
+
+ return transport;
}
EXPORT_SYMBOL_GPL(vsock_core_get_transport);

diff --git a/net/vmw_vsock/diag.c b/net/vmw_vsock/diag.c
index a2823b1c5e28..f843bae86b32 100644
--- a/net/vmw_vsock/diag.c
+++ b/net/vmw_vsock/diag.c
@@ -15,8 +15,16 @@ static int sk_diag_fill(struct sock *sk, struct sk_buff *skb,
u32 portid, u32 seq, u32 flags)
{
struct vsock_sock *vsk = vsock_sk(sk);
+ struct sockaddr_vm remote_addr;
struct vsock_diag_msg *rep;
struct nlmsghdr *nlh;
+
+ /* A socket without remote info (e.g. no transport assigned yet) has
+ * no peer to report. Fall back to the wildcard address instead of
+ * failing the fill, so such a socket is still reported by diag.
+ */
+ if (vsock_remote_addr_copy(vsk, &remote_addr) < 0)
+ vsock_addr_init(&remote_addr, VMADDR_CID_ANY, VMADDR_PORT_ANY);

nlh = nlmsg_put(skb, portid, seq, SOCK_DIAG_BY_FAMILY, sizeof(*rep),
flags);
@@ -36,8 +44,8 @@ static int sk_diag_fill(struct sock *sk, struct sk_buff *skb,
rep->vdiag_shutdown = sk->sk_shutdown;
rep->vdiag_src_cid = vsk->local_addr.svm_cid;
rep->vdiag_src_port = vsk->local_addr.svm_port;
- rep->vdiag_dst_cid = vsk->remote_addr.svm_cid;
- rep->vdiag_dst_port = vsk->remote_addr.svm_port;
+ rep->vdiag_dst_cid = remote_addr.svm_cid;
+ rep->vdiag_dst_port = remote_addr.svm_port;
rep->vdiag_ino = sock_i_ino(sk);

sock_diag_save_cookie(sk, rep->vdiag_cookie);
diff --git a/net/vmw_vsock/hyperv_transport.c b/net/vmw_vsock/hyperv_transport.c
index c00bc5da769a..84e8c64b3365 100644
--- a/net/vmw_vsock/hyperv_transport.c
+++ b/net/vmw_vsock/hyperv_transport.c
@@ -323,6 +323,8 @@ static void hvs_open_connection(struct vmbus_channel *chan)
goto out;

if (conn_from_host) {
+ struct sockaddr_vm remote_addr;
+
if (sk->sk_ack_backlog >= sk->sk_max_ack_backlog)
goto out;

@@ -336,10 +338,9 @@ static void hvs_open_connection(struct vmbus_channel *chan)
hvs_addr_init(&vnew->local_addr, if_type);

/* Remote peer is always the host */
- vsock_addr_init(&vnew->remote_addr,
- VMADDR_CID_HOST, VMADDR_PORT_ANY);
- vnew->remote_addr.svm_port = get_port_by_srv_id(if_instance);
- ret = vsock_assign_transport(vnew, vsock_sk(sk));
+ vsock_addr_init(&remote_addr, VMADDR_CID_HOST, get_port_by_srv_id(if_instance));
+
+ ret = vsock_assign_transport(vnew, vsock_sk(sk), &remote_addr);
/* Transport assigned (looking at remote_addr) must be the
* same where we received the request.
*/
@@ -459,13 +460,18 @@ static int hvs_connect(struct vsock_sock *vsk)
{
union hvs_service_id vm, host;
struct hvsock *h = vsk->trans;
+ int err;

vm.srv_id = srv_id_template;
vm.svm_port = vsk->local_addr.svm_port;
h->vm_srv_id = vm.srv_id;

host.srv_id = srv_id_template;
- host.svm_port = vsk->remote_addr.svm_port;
+
+ err = vsock_remote_addr_port(vsk, &host.svm_port);
+ if (err < 0)
+ return err;
+
h->host_srv_id = host.srv_id;

return vmbus_send_tl_connect_request(&h->vm_srv_id, &h->host_srv_id);
@@ -566,7 +572,8 @@ static int hvs_dgram_get_length(struct sk_buff *skb, size_t *len)
return -EOPNOTSUPP;
}

-static int hvs_dgram_enqueue(struct vsock_sock *vsk,
+static int hvs_dgram_enqueue(const struct vsock_transport *transport,
+ struct vsock_sock *vsk,
struct sockaddr_vm *remote, struct msghdr *msg,
size_t dgram_len)
{
@@ -866,7 +873,8 @@ static struct vsock_transport hvs_transport = {

static bool hvs_check_transport(struct vsock_sock *vsk)
{
- return vsk->transport == &hvs_transport;
+ /* vsock_core_get_transport() enters its own RCU read-side section. */
+ return vsock_core_get_transport(vsk) == &hvs_transport;
}

static int hvs_probe(struct hv_device *hdev,
diff --git a/net/vmw_vsock/virtio_transport_common.c b/net/vmw_vsock/virtio_transport_common.c
index bc9d459723f5..9d090f208648 100644
--- a/net/vmw_vsock/virtio_transport_common.c
+++ b/net/vmw_vsock/virtio_transport_common.c
@@ -259,8 +259,9 @@ static int virtio_transport_send_pkt_info(struct vsock_sock *vsk,
src_cid = t_ops->transport.get_local_cid();
src_port = vsk->local_addr.svm_port;
if (!info->remote_cid) {
- dst_cid = vsk->remote_addr.svm_cid;
- dst_port = vsk->remote_addr.svm_port;
+ ret = vsock_remote_addr_cid_port(vsk, &dst_cid, &dst_port);
+ if (ret < 0)
+ return ret;
} else {
dst_cid = info->remote_cid;
dst_port = info->remote_port;
@@ -878,12 +879,15 @@ int virtio_transport_shutdown(struct vsock_sock *vsk, int mode)
EXPORT_SYMBOL_GPL(virtio_transport_shutdown);

int
-virtio_transport_dgram_enqueue(struct vsock_sock *vsk,
+virtio_transport_dgram_enqueue(const struct vsock_transport *transport,
+ struct vsock_sock *vsk,
struct sockaddr_vm *remote_addr,
struct msghdr *msg,
size_t dgram_len)
{
- const struct virtio_transport *t_ops;
+ /* struct vsock_transport is embedded in struct virtio_transport. */
+ const struct virtio_transport *t_ops =
+ container_of(transport, struct virtio_transport, transport);
struct virtio_vsock_pkt_info info = {
.op = VIRTIO_VSOCK_OP_RW,
.msg = msg,
@@ -897,7 +901,6 @@ virtio_transport_dgram_enqueue(struct vsock_sock *vsk,
if (dgram_len > VIRTIO_VSOCK_MAX_PKT_BUF_SIZE)
return -EMSGSIZE;

- t_ops = virtio_transport_get_ops(vsk);
src_cid = t_ops->transport.get_local_cid();
src_port = vsk->local_addr.svm_port;

@@ -1121,7 +1124,11 @@ virtio_transport_recv_connecting(struct sock *sk,
case VIRTIO_VSOCK_OP_RESPONSE:
sk->sk_state = TCP_ESTABLISHED;
sk->sk_socket->state = SS_CONNECTED;
- vsock_insert_connected(vsk);
+ err = vsock_insert_connected(vsk);
+ if (err) {
+ skerr = ECONNRESET;
+ goto destroy;
+ }
sk->sk_state_change(sk);
break;
case VIRTIO_VSOCK_OP_INVALID:
@@ -1323,6 +1330,7 @@ virtio_transport_recv_listen(struct sock *sk, struct sk_buff *skb,
struct virtio_vsock_hdr *hdr = virtio_vsock_hdr(skb);
struct vsock_sock *vsk = vsock_sk(sk);
struct vsock_sock *vchild;
+ struct sockaddr_vm child_remote;
struct sock *child;
int ret;

@@ -1351,14 +1359,13 @@ virtio_transport_recv_listen(struct sock *sk, struct sk_buff *skb,
vchild = vsock_sk(child);
vsock_addr_init(&vchild->local_addr, le64_to_cpu(hdr->dst_cid),
le32_to_cpu(hdr->dst_port));
- vsock_addr_init(&vchild->remote_addr, le64_to_cpu(hdr->src_cid),
+ vsock_addr_init(&child_remote, le64_to_cpu(hdr->src_cid),
le32_to_cpu(hdr->src_port));
-
- ret = vsock_assign_transport(vchild, vsk);
+ ret = vsock_assign_transport(vchild, vsk, &child_remote);
/* Transport assigned (looking at remote_addr) must be the same
* where we received the request.
*/
- if (ret || vchild->transport != &t->transport) {
+ if (ret || vsock_core_get_transport(vchild) != &t->transport) {
release_sock(child);
virtio_transport_reset_no_sock(t, skb);
sock_put(child);
@@ -1368,7 +1375,13 @@ virtio_transport_recv_listen(struct sock *sk, struct sk_buff *skb,
if (virtio_transport_space_update(child, skb))
child->sk_write_space(child);

- vsock_insert_connected(vchild);
+ ret = vsock_insert_connected(vchild);
+ if (ret) {
+ release_sock(child);
+ virtio_transport_reset_no_sock(t, skb);
+ sock_put(child);
+ return ret;
+ }
vsock_enqueue_accept(sk, child);
virtio_transport_send_response(vchild, skb);

diff --git a/net/vmw_vsock/vmci_transport.c b/net/vmw_vsock/vmci_transport.c
index bbc63826bf48..943539857ccb 100644
--- a/net/vmw_vsock/vmci_transport.c
+++ b/net/vmw_vsock/vmci_transport.c
@@ -283,18 +283,21 @@ vmci_transport_send_control_pkt(struct sock *sk,
u16 proto,
struct vmci_handle handle)
{
+ struct sockaddr_vm remote_addr;
struct vsock_sock *vsk;

vsk = vsock_sk(sk);

if (!vsock_addr_bound(&vsk->local_addr))
return -EINVAL;

- if (!vsock_addr_bound(&vsk->remote_addr))
+ /* Validate and send the same snapshot of the remote address. */
+ if (vsock_remote_addr_copy(vsk, &remote_addr) < 0 ||
+ !vsock_addr_bound(&remote_addr))
return -EINVAL;

return vmci_transport_alloc_send_control_pkt(&vsk->local_addr,
- &vsk->remote_addr,
+ &remote_addr,
type, size, mode,
wait, proto, handle);
}
@@ -317,6 +320,7 @@ static int vmci_transport_send_reset(struct sock *sk,
struct sockaddr_vm *dst_ptr;
struct sockaddr_vm dst;
struct vsock_sock *vsk;
+ int err;

if (pkt->type == VMCI_TRANSPORT_PACKET_TYPE_RST)
return 0;
@@ -326,13 +330,16 @@ static int vmci_transport_send_reset(struct sock *sk,
if (!vsock_addr_bound(&vsk->local_addr))
return -EINVAL;

- if (vsock_addr_bound(&vsk->remote_addr)) {
- dst_ptr = &vsk->remote_addr;
+ if (vsock_remote_addr_bound(vsk)) {
+ err = vsock_remote_addr_copy(vsk, &dst);
+ if (err < 0)
+ return err;
} else {
vsock_addr_init(&dst, pkt->dg.src.context,
pkt->src_port);
- dst_ptr = &dst;
}
+ dst_ptr = &dst;
+
return vmci_transport_alloc_send_control_pkt(&vsk->local_addr, dst_ptr,
VMCI_TRANSPORT_PACKET_TYPE_RST,
0, 0, NULL, VSOCK_PROTO_INVALID,
@@ -490,7 +497,7 @@ static struct sock *vmci_transport_get_pending(

list_for_each_entry(vpending, &vlistener->pending_links,
pending_links) {
- if (vsock_addr_equals_addr(&src, &vpending->remote_addr) &&
+ if (vsock_remote_addr_equals(vpending, &src) &&
pkt->dst_port == vpending->local_addr.svm_port) {
pending = sk_vsock(vpending);
sock_hold(pending);
@@ -940,6 +947,7 @@ static void vmci_transport_recv_pkt_work(struct work_struct *work)
static int vmci_transport_recv_listen(struct sock *sk,
struct vmci_transport_packet *pkt)
{
+ struct sockaddr_vm remote_addr;
struct sock *pending;
struct vsock_sock *vpending;
int err;
@@ -1015,10 +1023,10 @@ static int vmci_transport_recv_listen(struct sock *sk,

vsock_addr_init(&vpending->local_addr, pkt->dg.dst.context,
pkt->dst_port);
- vsock_addr_init(&vpending->remote_addr, pkt->dg.src.context,
- pkt->src_port);

- err = vsock_assign_transport(vpending, vsock_sk(sk));
+ vsock_addr_init(&remote_addr, pkt->dg.src.context, pkt->src_port);
+
+ err = vsock_assign_transport(vpending, vsock_sk(sk), &remote_addr);
/* Transport assigned (looking at remote_addr) must be the same
* where we received the request.
*/
@@ -1133,6 +1141,7 @@ vmci_transport_recv_connecting_server(struct sock *listener,
{
struct vsock_sock *vpending;
struct vmci_handle handle;
+ unsigned int vpending_remote_cid;
struct vmci_qp *qpair;
bool is_local;
u32 flags;
@@ -1189,8 +1198,13 @@ vmci_transport_recv_connecting_server(struct sock *listener,
/* vpending->local_addr always has a context id so we do not need to
* worry about VMADDR_CID_ANY in this case.
*/
- is_local =
- vpending->remote_addr.svm_cid == vpending->local_addr.svm_cid;
+ err = vsock_remote_addr_cid(vpending, &vpending_remote_cid);
+ if (err < 0) {
+ skerr = EPROTO;
+ goto destroy;
+ }
+
+ is_local = vpending_remote_cid == vpending->local_addr.svm_cid;
flags = VMCI_QPFLAG_ATTACH_ONLY;
flags |= is_local ? VMCI_QPFLAG_LOCAL : 0;

@@ -1203,7 +1217,7 @@ vmci_transport_recv_connecting_server(struct sock *listener,
flags,
vmci_transport_is_trusted(
vpending,
- vpending->remote_addr.svm_cid));
+ vpending_remote_cid));
if (err < 0) {
vmci_transport_send_reset(pending, pkt);
skerr = -err;
@@ -1277,6 +1291,8 @@ static int
vmci_transport_recv_connecting_client(struct sock *sk,
struct vmci_transport_packet *pkt)
{
+ struct vsock_remote_info *remote_info;
+ struct sockaddr_vm *remote_addr;
struct vsock_sock *vsk;
int err;
int skerr;
@@ -1306,9 +1322,20 @@ vmci_transport_recv_connecting_client(struct sock *sk,
break;
case VMCI_TRANSPORT_PACKET_TYPE_NEGOTIATE:
case VMCI_TRANSPORT_PACKET_TYPE_NEGOTIATE2:
+ rcu_read_lock();
+ remote_info = vsock_core_get_remote_info(vsk);
+ if (!remote_info) {
+ skerr = EPROTO;
+ err = -EINVAL;
+ rcu_read_unlock();
+ goto destroy;
+ }
+
+ remote_addr = &remote_info->addr;
+
if (pkt->u.size == 0
- || pkt->dg.src.context != vsk->remote_addr.svm_cid
- || pkt->src_port != vsk->remote_addr.svm_port
+ || pkt->dg.src.context != remote_addr->svm_cid
+ || pkt->src_port != remote_addr->svm_port
|| !vmci_handle_is_invalid(vmci_trans(vsk)->qp_handle)
|| vmci_trans(vsk)->qpair
|| vmci_trans(vsk)->produce_size != 0
@@ -1316,9 +1343,10 @@ vmci_transport_recv_connecting_client(struct sock *sk,
|| vmci_trans(vsk)->detach_sub_id != VMCI_INVALID_ID) {
skerr = EPROTO;
err = -EINVAL;
-
+ rcu_read_unlock();
goto destroy;
}
+ rcu_read_unlock();

err = vmci_transport_recv_connecting_client_negotiate(sk, pkt);
if (err) {
@@ -1379,6 +1407,7 @@ static int vmci_transport_recv_connecting_client_negotiate(
int err;
struct vsock_sock *vsk;
struct vmci_handle handle;
+ unsigned int remote_cid;
struct vmci_qp *qpair;
u32 detach_sub_id;
bool is_local;
@@ -1449,19 +1478,23 @@ static int vmci_transport_recv_connecting_client_negotiate(

/* Make VMCI select the handle for us. */
handle = VMCI_INVALID_HANDLE;
- is_local = vsk->remote_addr.svm_cid == vsk->local_addr.svm_cid;
+
+ err = vsock_remote_addr_cid(vsk, &remote_cid);
+ if (err < 0)
+ goto destroy;
+
+ is_local = remote_cid == vsk->local_addr.svm_cid;
flags = is_local ? VMCI_QPFLAG_LOCAL : 0;

err = vmci_transport_queue_pair_alloc(&qpair,
&handle,
pkt->u.size,
pkt->u.size,
- vsk->remote_addr.svm_cid,
+ remote_cid,
flags,
vmci_transport_is_trusted(
vsk,
- vsk->
- remote_addr.svm_cid));
+ remote_cid));
if (err < 0)
goto destroy;

@@ -1692,6 +1725,7 @@ static int vmci_transport_dgram_bind(struct vsock_sock *vsk,
}

static int vmci_transport_dgram_enqueue(
+ const struct vsock_transport *transport,
struct vsock_sock *vsk,
struct sockaddr_vm *remote_addr,
struct msghdr *msg,
@@ -2052,7 +2086,8 @@ static struct vsock_transport vmci_transport = {

static bool vmci_check_transport(struct vsock_sock *vsk)
{
- return vsk->transport == &vmci_transport;
+ /* vsock_core_get_transport() enters its own RCU read-side section. */
+ return vsock_core_get_transport(vsk) == &vmci_transport;
}

static void vmci_vsock_transport_cb(bool is_host)
diff --git a/net/vmw_vsock/vsock_bpf.c b/net/vmw_vsock/vsock_bpf.c
index a3c97546ab84..4d811c9cdf6e 100644
--- a/net/vmw_vsock/vsock_bpf.c
+++ b/net/vmw_vsock/vsock_bpf.c
@@ -148,6 +148,7 @@ static void vsock_bpf_check_needs_rebuild(struct proto *ops)

int vsock_bpf_update_proto(struct sock *sk, struct sk_psock *psock, bool restore)
{
+ const struct vsock_transport *transport;
struct vsock_sock *vsk;

if (restore) {
@@ -157,10 +158,12 @@ int vsock_bpf_update_proto(struct sock *sk, struct sk_psock *psock, bool restore
}

vsk = vsock_sk(sk);
- if (!vsk->transport)
+ /* vsock_core_get_transport() enters its own RCU read-side section. */
+ transport = vsock_core_get_transport(vsk);
+ if (!transport)
return -ENODEV;

- if (!vsk->transport->read_skb)
+ if (!transport->read_skb)
return -EOPNOTSUPP;

vsock_bpf_check_needs_rebuild(psock->sk_proto);

--
2.30.2