[RFCv3 04/15] selftests: tcp_authopt: Initial sockopt manipulation

From: Leonard Crestez
Date: Tue Aug 24 2021 - 17:36:10 EST


Signed-off-by: Leonard Crestez <cdleonard@xxxxxxxxx>
---
.../tcp_authopt/tcp_authopt_test/conftest.py | 21 ++
.../tcp_authopt_test/linux_tcp_authopt.py | 188 ++++++++++++++++++
.../tcp_authopt/tcp_authopt_test/sockaddr.py | 101 ++++++++++
.../tcp_authopt_test/test_sockopt.py | 74 +++++++
4 files changed, 384 insertions(+)
create mode 100644 tools/testing/selftests/tcp_authopt/tcp_authopt_test/conftest.py
create mode 100644 tools/testing/selftests/tcp_authopt/tcp_authopt_test/linux_tcp_authopt.py
create mode 100644 tools/testing/selftests/tcp_authopt/tcp_authopt_test/sockaddr.py
create mode 100644 tools/testing/selftests/tcp_authopt/tcp_authopt_test/test_sockopt.py

diff --git a/tools/testing/selftests/tcp_authopt/tcp_authopt_test/conftest.py b/tools/testing/selftests/tcp_authopt/tcp_authopt_test/conftest.py
new file mode 100644
index 000000000000..c17c8ea2a943
--- /dev/null
+++ b/tools/testing/selftests/tcp_authopt/tcp_authopt_test/conftest.py
@@ -0,0 +1,21 @@
+# SPDX-License-Identifier: GPL-2.0
+from tcp_authopt_test.linux_tcp_authopt import has_tcp_authopt
+import pytest
+import logging
+from contextlib import ExitStack
+
logger = logging.getLogger(__name__)

# Probe kernel support once, when this conftest module is imported; tests
# decorated with (or module-marked by) the marker below are skipped when the
# TCP_AUTHOPT sockopt is not available.
_kernel_has_tcp_authopt = has_tcp_authopt()
skipif_missing_tcp_authopt = pytest.mark.skipif(
    not _kernel_has_tcp_authopt,
    reason="Need CONFIG_TCP_AUTHOPT",
)
+
+
@pytest.fixture
def exit_stack():
    """Yield a contextlib.ExitStack that is unwound when the test finishes.

    Tests register sockets and other context managers on the stack instead of
    nesting ``with`` statements, which keeps the test bodies flat.
    """
    with ExitStack() as stack:
        yield stack
diff --git a/tools/testing/selftests/tcp_authopt/tcp_authopt_test/linux_tcp_authopt.py b/tools/testing/selftests/tcp_authopt/tcp_authopt_test/linux_tcp_authopt.py
new file mode 100644
index 000000000000..41374f9851aa
--- /dev/null
+++ b/tools/testing/selftests/tcp_authopt/tcp_authopt_test/linux_tcp_authopt.py
@@ -0,0 +1,188 @@
+# SPDX-License-Identifier: GPL-2.0
+"""Python wrapper around linux TCP_AUTHOPT ABI"""
+
+from dataclasses import dataclass
+from ipaddress import IPv4Address, IPv6Address, ip_address
+import socket
+import errno
+import logging
+from .sockaddr import sockaddr_in, sockaddr_in6, sockaddr_storage, sockaddr_unpack
+import typing
+import struct
+
logger = logging.getLogger(__name__)


def BIT(x):
    """Return an int with only bit number *x* set (like the kernel's BIT())."""
    single_bit = 1 << x
    return single_bit


# Socket option names, used with the IPPROTO_TCP level.
TCP_AUTHOPT = 38
TCP_AUTHOPT_KEY = 39

# Width of the fixed key buffer inside the packed tcp_authopt_key.
TCP_AUTHOPT_MAXKEYLEN = 80

# Flag bits (by name) for tcp_authopt.flags.
TCP_AUTHOPT_FLAG_REJECT_UNEXPECTED = BIT(2)

# Flag bits for tcp_authopt_key.flags.
TCP_AUTHOPT_KEY_DEL = BIT(0)
TCP_AUTHOPT_KEY_EXCLUDE_OPTS = BIT(1)
TCP_AUTHOPT_KEY_BIND_ADDR = BIT(2)

# Values for tcp_authopt_key.alg.
TCP_AUTHOPT_ALG_HMAC_SHA_1_96 = 1
TCP_AUTHOPT_ALG_AES_128_CMAC_96 = 2
+
+
@dataclass
class tcp_authopt:
    """Python counterpart of struct tcp_authopt: a single native u32 flags word."""

    flags: int = 0
    sizeof = 4  # packed size in bytes; unannotated, so not a dataclass field

    def pack(self) -> bytes:
        """Serialize ``flags`` as a native-endian unsigned int."""
        packed = struct.pack("I", self.flags)
        return packed

    def __bytes__(self):
        return self.pack()

    @classmethod
    def unpack(cls, b: bytes):
        """Build an instance from exactly ``sizeof`` bytes (e.g. getsockopt output)."""
        (flags,) = struct.unpack("I", b)
        return cls(flags)
+
+
def set_tcp_authopt(sock, opt: tcp_authopt):
    """Apply *opt* to *sock* through setsockopt(IPPROTO_TCP, TCP_AUTHOPT)."""
    payload = bytes(opt)
    return sock.setsockopt(socket.IPPROTO_TCP, TCP_AUTHOPT, payload)
+
+
def get_tcp_authopt(sock: socket.socket) -> tcp_authopt:
    """Fetch the TCP_AUTHOPT option from *sock* and decode it into tcp_authopt."""
    return tcp_authopt.unpack(
        sock.getsockopt(socket.IPPROTO_TCP, TCP_AUTHOPT, tcp_authopt.sizeof)
    )
+
+
class tcp_authopt_key:
    """Like linux struct tcp_authopt_key

    ``pack()`` emits, in native byte order: u32 flags, u8 send_id, u8 recv_id,
    u8 alg, u8 key length, a zero-padded ``TCP_AUTHOPT_MAXKEYLEN``-byte key
    buffer, and finally the address zero-padded to ``sockaddr_storage.sizeof``.

    ``addr`` accepts raw sockaddr bytes, ``None``, an IP address string, an
    ``IPv4Address``/``IPv6Address`` or a sockaddr wrapper object; it is kept
    as raw bytes in ``addrbuf``.
    """

    def __init__(
        self,
        flags: int = 0,
        send_id: int = 0,
        recv_id: int = 0,
        alg=TCP_AUTHOPT_ALG_HMAC_SHA_1_96,
        key: typing.Union[bytes, str] = b"",
        addr=b"",
        include_options=None,
    ):
        self.flags = flags
        self.send_id = send_id
        self.recv_id = recv_id
        self.alg = alg
        self.key = key
        self.addr = addr
        # None leaves the TCP_AUTHOPT_KEY_EXCLUDE_OPTS bit of ``flags`` as given.
        if include_options is not None:
            self.include_options = include_options

    def pack(self) -> bytes:
        """Serialize to the buffer passed to setsockopt(TCP_AUTHOPT_KEY).

        Raises ValueError if the key is longer than TCP_AUTHOPT_MAXKEYLEN
        (only possible when ``_key`` was assigned bypassing the setter).
        """
        if len(self.key) > TCP_AUTHOPT_MAXKEYLEN:
            raise ValueError(f"Max key length is {TCP_AUTHOPT_MAXKEYLEN}")
        # Derive the key field width from the constant so the two cannot drift.
        header = struct.pack(
            f"IBBBB{TCP_AUTHOPT_MAXKEYLEN}s",
            self.flags,
            self.send_id,
            self.recv_id,
            self.alg,
            len(self.key),
            self.key,
        )
        return header + self.addrbuf.ljust(sockaddr_storage.sizeof, b"\x00")

    def __bytes__(self):
        return self.pack()

    @property
    def key(self) -> bytes:
        """Raw key bytes."""
        return self._key

    @key.setter
    def key(self, val: typing.Union[bytes, str]) -> None:
        # str keys are UTF-8 encoded; the length limit applies to the encoded bytes.
        if isinstance(val, str):
            val = val.encode("utf-8")
        if len(val) > TCP_AUTHOPT_MAXKEYLEN:
            raise ValueError(f"Max key length is {TCP_AUTHOPT_MAXKEYLEN}")
        self._key = val

    @property
    def addr(self):
        """Decoded address object, or None when no address is set.

        The buffer is zero-padded to sockaddr_storage size first, exactly as
        pack() sends it, so short raw buffers decode instead of raising.
        """
        if not self.addrbuf:
            return None
        return sockaddr_unpack(self.addrbuf.ljust(sockaddr_storage.sizeof, b"\x00"))

    @addr.setter
    def addr(self, val) -> None:
        if isinstance(val, bytes):
            if len(val) > sockaddr_storage.sizeof:
                raise ValueError(f"Must be up to {sockaddr_storage.sizeof}")
            self.addrbuf = val
        elif val is None:
            self.addrbuf = b""
        elif isinstance(val, str):
            self.addr = ip_address(val)
        elif isinstance(val, IPv4Address):
            self.addr = sockaddr_in(addr=val)
        elif isinstance(val, IPv6Address):
            self.addr = sockaddr_in6(addr=val)
        elif isinstance(val, (sockaddr_in, sockaddr_in6, sockaddr_storage)):
            self.addr = bytes(val)
        else:
            raise TypeError(f"Can't handle addr {val}")
        # Deliberately no read-back through the getter: storing raw bytes must
        # not require them to be decodable.

    @property
    def include_options(self) -> bool:
        """True unless TCP_AUTHOPT_KEY_EXCLUDE_OPTS is set in ``flags``."""
        return (self.flags & TCP_AUTHOPT_KEY_EXCLUDE_OPTS) == 0

    @include_options.setter
    def include_options(self, value) -> None:
        if value:
            self.flags &= ~TCP_AUTHOPT_KEY_EXCLUDE_OPTS
        else:
            self.flags |= TCP_AUTHOPT_KEY_EXCLUDE_OPTS

    @property
    def delete_flag(self) -> bool:
        """Whether TCP_AUTHOPT_KEY_DEL is set in ``flags``."""
        return bool(self.flags & TCP_AUTHOPT_KEY_DEL)

    @delete_flag.setter
    def delete_flag(self, value) -> None:
        if value:
            self.flags |= TCP_AUTHOPT_KEY_DEL
        else:
            self.flags &= ~TCP_AUTHOPT_KEY_DEL
+
+
def set_tcp_authopt_key(sock, key: tcp_authopt_key):
    """Add, update or delete a key on *sock* via setsockopt(IPPROTO_TCP, TCP_AUTHOPT_KEY)."""
    payload = bytes(key)
    return sock.setsockopt(socket.IPPROTO_TCP, TCP_AUTHOPT_KEY, payload)
+
+
def has_tcp_authopt() -> bool:
    """Check if TCP_AUTHOPT is implemented by the OS

    Probes by setting an all-zero TCP_AUTHOPT option on a throwaway IPv4 TCP
    socket. ENOPROTOOPT means "not supported"; any other OSError is re-raised.
    """
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as probe:
        try:
            probe.setsockopt(socket.IPPROTO_TCP, TCP_AUTHOPT, bytes(4))
        except OSError as err:
            if err.errno != errno.ENOPROTOOPT:
                raise
            return False
        return True
diff --git a/tools/testing/selftests/tcp_authopt/tcp_authopt_test/sockaddr.py b/tools/testing/selftests/tcp_authopt/tcp_authopt_test/sockaddr.py
new file mode 100644
index 000000000000..f61d0f190a0c
--- /dev/null
+++ b/tools/testing/selftests/tcp_authopt/tcp_authopt_test/sockaddr.py
@@ -0,0 +1,101 @@
+# SPDX-License-Identifier: GPL-2.0
+"""pack/unpack wrappers for sockaddr"""
+import socket
+import struct
+from dataclasses import dataclass
+from ipaddress import IPv4Address, IPv6Address
+
+
@dataclass
class sockaddr_in:
    """Like struct sockaddr_in, without the trailing sin_zero padding.

    ``port`` holds a host-order value; on the wire sin_port is big-endian
    (network byte order), so pack()/unpack() convert with htons/ntohs.
    sin_family stays in host byte order.
    """

    port: int
    addr: IPv4Address
    sizeof = 8

    def __init__(self, port=0, addr=None):
        self.port = port
        if addr is None:
            addr = IPv4Address(0)
        self.addr = IPv4Address(addr)

    def pack(self):
        # sin_port must be in network byte order, not host order.
        return struct.pack(
            "HH4s", socket.AF_INET, socket.htons(self.port), self.addr.packed
        )

    @classmethod
    def unpack(cls, buffer):
        family, port, addr_packed = struct.unpack("HH4s", buffer[: cls.sizeof])
        if family != socket.AF_INET:
            raise ValueError(f"Must be AF_INET not {family}")
        return cls(socket.ntohs(port), addr_packed)

    def __bytes__(self):
        return self.pack()
+
+
@dataclass
class sockaddr_in6:
    """Like sockaddr_in6 but for python. Always contains scope_id

    ``port`` and ``flowinfo`` hold host-order values. sin6_port and
    sin6_flowinfo are big-endian on the wire, so pack()/unpack() convert them
    with htons/htonl and ntohs/ntohl. sin6_family and sin6_scope_id are kept
    in host byte order.
    """

    port: int
    addr: IPv6Address
    flowinfo: int
    scope_id: int
    sizeof = 28

    def __init__(self, port=0, addr=None, flowinfo=0, scope_id=0):
        self.port = port
        if addr is None:
            addr = IPv6Address(0)
        self.addr = IPv6Address(addr)
        self.flowinfo = flowinfo
        self.scope_id = scope_id

    def pack(self):
        return struct.pack(
            "HHI16sI",
            socket.AF_INET6,
            socket.htons(self.port),
            socket.htonl(self.flowinfo),
            self.addr.packed,
            self.scope_id,
        )

    @classmethod
    def unpack(cls, buffer):
        family, port, flowinfo, addr_packed, scope_id = struct.unpack(
            "HHI16sI", buffer[: cls.sizeof]
        )
        if family != socket.AF_INET6:
            raise ValueError(f"Must be AF_INET6 not {family}")
        return cls(
            socket.ntohs(port),
            addr_packed,
            flowinfo=socket.ntohl(flowinfo),
            scope_id=scope_id,
        )

    def __bytes__(self):
        return self.pack()
+
+
@dataclass
class sockaddr_storage:
    """Generic 128-byte sockaddr: a host-order u16 family plus opaque data."""

    family: int
    data: bytes
    sizeof = 128

    def pack(self):
        # struct pads (or truncates) ``data`` to exactly 126 bytes.
        return struct.pack("H126s", self.family, self.data)

    def __bytes__(self):
        return self.pack()

    @classmethod
    def unpack(cls, buffer):
        # Callers (e.g. sockaddr_unpack) may pass shorter or longer buffers;
        # normalize to exactly ``sizeof`` bytes, zero-padding any shortfall.
        buffer = bytes(buffer[: cls.sizeof]).ljust(cls.sizeof, b"\x00")
        return cls(*struct.unpack("H126s", buffer))
+
+
def sockaddr_unpack(buffer: bytes):
    """Decode *buffer* with the wrapper class matching its leading family field.

    AF_INET and AF_INET6 map to sockaddr_in / sockaddr_in6; any other family
    falls back to sockaddr_storage.
    """
    (family,) = struct.unpack("H", buffer[:2])
    decoders = {socket.AF_INET: sockaddr_in, socket.AF_INET6: sockaddr_in6}
    return decoders.get(family, sockaddr_storage).unpack(buffer)
diff --git a/tools/testing/selftests/tcp_authopt/tcp_authopt_test/test_sockopt.py b/tools/testing/selftests/tcp_authopt/tcp_authopt_test/test_sockopt.py
new file mode 100644
index 000000000000..06a05bf8aeec
--- /dev/null
+++ b/tools/testing/selftests/tcp_authopt/tcp_authopt_test/test_sockopt.py
@@ -0,0 +1,74 @@
+# SPDX-License-Identifier: GPL-2.0
+"""Test TCP_AUTHOPT sockopt API"""
+import errno
+import socket
+import struct
+from ipaddress import IPv4Address, IPv6Address
+
+import pytest
+
+from .linux_tcp_authopt import (
+ set_tcp_authopt,
+ set_tcp_authopt_key,
+ tcp_authopt,
+ tcp_authopt_key,
+)
+from .sockaddr import sockaddr_unpack
+from .conftest import skipif_missing_tcp_authopt
+
# Module-wide marker: every test below is skipped when the skipif condition
# from conftest (not has_tcp_authopt()) is true.
pytestmark = skipif_missing_tcp_authopt
+
+
def test_authopt_key_pack_noaddr():
    packed = bytes(tcp_authopt_key(key=b"a\x00b"))
    # Offset 7 is the key length byte; the key buffer starts at offset 8 and
    # is zero-padded, so an embedded NUL must survive.
    keylen = packed[7]
    assert keylen == 3
    assert packed[8:13] == b"a\x00b\x00\x00"
+
+
def test_authopt_key_pack_addr():
    packed = bytes(tcp_authopt_key(key=b"a\x00b", addr="10.0.0.1"))
    # The address area follows the 88-byte fixed header.
    addr_area = packed[88:]
    (family,) = struct.unpack("H", addr_area[:2])
    assert family == socket.AF_INET
    assert sockaddr_unpack(addr_area).addr == IPv4Address("10.0.0.1")
+
+
def test_authopt_key_pack_addr6():
    packed = bytes(tcp_authopt_key(key=b"abc", addr="fd00::1"))
    # The address area follows the 88-byte fixed header.
    addr_area = packed[88:]
    (family,) = struct.unpack("H", addr_area[:2])
    assert family == socket.AF_INET6
    assert sockaddr_unpack(addr_area).addr == IPv6Address("fd00::1")
+
+
def test_tcp_authopt_key_del_without_active(exit_stack):
    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    exit_stack.push(sock)

    # Deleting a key on a socket where TCP_AUTHOPT was never set must fail
    # with EINVAL or ENOENT.
    key = tcp_authopt_key()
    assert key.delete_flag is False
    key.delete_flag = True
    assert key.delete_flag is True
    with pytest.raises(OSError) as exc_info:
        set_tcp_authopt_key(sock, key)
    assert exc_info.value.errno in (errno.EINVAL, errno.ENOENT)
+
+
def test_tcp_authopt_key_setdel(exit_stack):
    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    exit_stack.push(sock)
    set_tcp_authopt(sock, tcp_authopt())

    # TCP_AUTHOPT is enabled but no key exists yet: delete fails with ENOENT.
    missing = tcp_authopt_key()
    missing.delete_flag = True
    with pytest.raises(OSError) as exc_info:
        set_tcp_authopt_key(sock, missing)
    assert exc_info.value.errno == errno.ENOENT

    # Add a key, then delete it through the same object.
    key = tcp_authopt_key(send_id=1, recv_id=2)
    set_tcp_authopt_key(sock, key)
    key.delete_flag = True
    set_tcp_authopt_key(sock, key)

    # A second delete of the same key fails with ENOENT again.
    with pytest.raises(OSError) as exc_info:
        set_tcp_authopt_key(sock, key)
    assert exc_info.value.errno == errno.ENOENT
--
2.25.1