293 lines
9.4 KiB
Python
293 lines
9.4 KiB
Python
|
#!/usr/bin/env python3
|
||
|
""" Wg tests """
|
||
|
|
||
|
from scapy.packet import Packet
|
||
|
from scapy.packet import Raw
|
||
|
from scapy.layers.l2 import Ether
|
||
|
from scapy.layers.inet import IP, UDP
|
||
|
from scapy.contrib.wireguard import Wireguard, WireguardResponse, \
|
||
|
WireguardInitiation
|
||
|
from cryptography.hazmat.primitives.asymmetric.x25519 import X25519PrivateKey
|
||
|
from cryptography.hazmat.primitives.serialization import Encoding, \
|
||
|
PrivateFormat, PublicFormat, NoEncryption
|
||
|
|
||
|
from vpp_ipip_tun_interface import VppIpIpTunInterface
|
||
|
from vpp_interface import VppInterface
|
||
|
from vpp_object import VppObject
|
||
|
from framework import VppTestCase
|
||
|
from re import compile
|
||
|
import unittest
|
||
|
|
||
|
""" TestWg is a subclass of the VppTestCase class.
|
||
|
|
||
|
Wg test.
|
||
|
|
||
|
"""
|
||
|
|
||
|
|
||
|
class VppWgInterface(VppInterface):
    """
    Test-side representation of a VPP WireGuard interface.

    Holds the source address, listening port and (optional) private key
    used to create the interface through the binary API, and knows how to
    create, delete and look up that interface in VPP.
    """

    def __init__(self, test, src, port, key=None):
        """
        :param test: test case object, handed to the VppInterface base.
        :param src: source IP address of the interface.
        :param port: port of the interface.
        :param key: optional private key object; it must provide
                    ``private_bytes()`` (see key_bytes()).
        """
        super(VppWgInterface, self).__init__(test)

        self.key = key
        # Remember whether a key was supplied by the caller.
        self.generate = not self.key
        self.port = port
        self.src = src

    def add_vpp_config(self):
        """
        Create the interface in VPP and register it for cleanup.

        :returns: self, so that calls can be chained.
        """
        wg_itf = {
            'user_instance': 0xffffffff,
            'port': self.port,
            'src_ip': self.src,
            'private_key': self.key_bytes()
        }
        reply = self.test.vapi.wireguard_interface_create(interface=wg_itf)
        self.set_sw_if_index(reply.sw_if_index)
        self.test.registry.register(self, self.test.logger)
        return self

    def key_bytes(self):
        """
        Return the raw private key, or 32 zero bytes when no key is set.
        """
        if not self.key:
            return bytearray(32)
        return self.key.private_bytes(Encoding.Raw,
                                      PrivateFormat.Raw,
                                      NoEncryption())

    def remove_vpp_config(self):
        """ Delete the interface from VPP. """
        self.test.vapi.wireguard_interface_delete(
            sw_if_index=self._sw_if_index)

    def query_vpp_config(self):
        """
        Tell whether VPP reports an interface matching this object.

        :returns: True if a dumped interface has the same sw_if_index,
                  source address, port and private key; False otherwise.
        """
        dump = self.test.vapi.wireguard_interface_dump(
            sw_if_index=0xffffffff)
        return any(
            entry.interface.sw_if_index == self._sw_if_index and
            str(entry.interface.src_ip) == self.src and
            entry.interface.port == self.port and
            entry.interface.private_key == self.key_bytes()
            for entry in dump)

    def __str__(self):
        return self.object_id()

    def object_id(self):
        """ Return the identifier string built from the sw_if_index. """
        return "wireguard-%d" % self._sw_if_index
|
||
|
|
||
|
|
||
|
def find_route(test, prefix, table_id=0, is_ip6=False):
    """
    Check whether a route for ``prefix`` shows up in a FIB table dump.

    :param test: test case object; ``test.vapi.ip_route_dump`` is called
                 to fetch the routes.
    :param prefix: prefix to look for; compared by its ``str()`` form, so
                   both strings and prefix objects are accepted.
    :param table_id: table to dump and to match entries against.
    :param is_ip6: value forwarded as the second argument of
                   ``ip_route_dump`` (previously hard-coded to False).
                   The default keeps the former behaviour.
    :returns: True if an entry with the same table id and prefix exists,
              False otherwise.
    """
    wanted = str(prefix)
    return any(
        entry.route.table_id == table_id and
        str(entry.route.prefix) == wanted
        for entry in test.vapi.ip_route_dump(table_id, is_ip6))
|
||
|
|
||
|
|
||
|
class VppWgPeer(VppObject):
    """
    Test-side representation of a WireGuard peer configured in VPP.

    Every instance generates its own X25519 key pair; the public half is
    what gets sent to VPP when the peer is added.
    """

    def __init__(self,
                 test,
                 itf,
                 endpoint,
                 port,
                 allowed_ips,
                 persistent_keepalive=15):
        """
        :param test: test case object (provides vapi, registry, logger and
                     the assert helpers).
        :param itf: interface object the peer belongs to; its
                    ``sw_if_index`` is sent to VPP.
        :param endpoint: peer endpoint address.
        :param port: peer port.
        :param allowed_ips: list of prefixes allowed for this peer.
        :param persistent_keepalive: keepalive value passed to VPP.
        """
        self._test = test
        self.itf = itf
        self.endpoint = endpoint
        self.port = port
        self.allowed_ips = allowed_ips
        self.persistent_keepalive = persistent_keepalive
        self.private_key = X25519PrivateKey.generate()
        self.public_key = self.private_key.public_key()
        # Not read anywhere in this class.
        self.hash = bytearray(16)

    def validate_routing(self):
        """ Assert that a route exists for every allowed prefix. """
        for a in self.allowed_ips:
            self._test.assertTrue(find_route(self._test, a))

    def validate_no_routing(self):
        """ Assert that no route exists for any allowed prefix. """
        for a in self.allowed_ips:
            self._test.assertFalse(find_route(self._test, a))

    def add_vpp_config(self):
        """
        Add the peer to VPP, register it for cleanup and check its routes.

        :returns: self, so that calls can be chained.
        """
        rv = self._test.vapi.wireguard_peer_add(
            peer={
                'public_key': self.public_key_bytes(),
                'port': self.port,
                'endpoint': self.endpoint,
                'n_allowed_ips': len(self.allowed_ips),
                'allowed_ips': self.allowed_ips,
                'sw_if_index': self.itf.sw_if_index,
                'persistent_keepalive': self.persistent_keepalive})
        self.index = rv.peer_index
        self._test.registry.register(self, self._test.logger)
        self.validate_routing()
        return self

    def remove_vpp_config(self):
        """ Remove the peer from VPP and check its routes are gone. """
        self._test.vapi.wireguard_peer_remove(peer_index=self.index)
        self.validate_no_routing()

    def object_id(self):
        """ Return the identifier string built from the peer index. """
        return ("wireguard-peer-%s" % self.index)

    def public_key_bytes(self):
        """ Return the raw bytes of the peer's public key. """
        return self.public_key.public_bytes(Encoding.Raw,
                                            PublicFormat.Raw)

    def private_key_bytes(self):
        """ Return the raw, unencrypted bytes of the peer's private key. """
        return self.private_key.private_bytes(Encoding.Raw,
                                              PrivateFormat.Raw,
                                              NoEncryption())

    def query_vpp_config(self):
        """
        Tell whether VPP reports a peer matching this object.

        A dumped peer matches when public key, port, endpoint, interface
        index and the set of allowed prefixes all agree.

        :returns: True on a match, False otherwise.
        """
        # Compare the prefixes by their string form on both sides. Sorting
        # the two lists separately with their native ordering (as was done
        # before) can pair up different prefixes when the element types
        # differ, and sorting in place modified the caller's list.
        wanted_ips = sorted(str(a) for a in self.allowed_ips)
        wanted_key = self.public_key_bytes()

        for p in self._test.vapi.wireguard_peers_dump():
            peer = p.peer
            if peer.public_key == wanted_key and \
               peer.port == self.port and \
               str(peer.endpoint) == self.endpoint and \
               peer.sw_if_index == self.itf.sw_if_index and \
               len(self.allowed_ips) == peer.n_allowed_ips:
                # First peer with matching identity decides the result.
                return wanted_ips == sorted(str(a)
                                            for a in peer.allowed_ips)
        return False
|
||
|
|
||
|
|
||
|
class TestWg(VppTestCase):
    """ Wireguard Test Case """

    # Compiled pattern; not referenced by the tests in this class.
    error_str = compile(r"Error")

    @classmethod
    def setUpClass(cls):
        # Bring up three pg interfaces (pg0..pg2) with IPv4 configured and
        # ARP resolved; they are shared by all tests in the class.
        super(TestWg, cls).setUpClass()
        try:
            cls.create_pg_interfaces(range(3))
            for i in cls.pg_interfaces:
                i.admin_up()
                i.config_ip4()
                i.resolve_arp()

        except Exception:
            # Undo the base-class setup before propagating the failure.
            super(TestWg, cls).tearDownClass()
            raise

    @classmethod
    def tearDownClass(cls):
        super(TestWg, cls).tearDownClass()

    def test_wg_interface(self):
        # Create a WireGuard interface without an explicit key, log the
        # interface list, then delete the interface again.
        port = 12312

        # Create interface
        wg0 = VppWgInterface(self,
                             self.pg1.local_ip4,
                             port).add_vpp_config()

        self.logger.info(self.vapi.cli("sh int"))

        # delete interface
        wg0.remove_vpp_config()

    def test_wg_peer(self):
        # Exercise peer add/remove on two WireGuard interfaces and check
        # the initial handshake sent towards the first peer.

        # Prefixes of the error-counter paths for the wg graph nodes.
        wg_output_node_name = '/err/wg-output-tun/'
        # NOTE(review): wg_input_node_name is not used below.
        wg_input_node_name = '/err/wg-input/'

        port = 12323

        # Create interfaces
        # wg0 gets an explicit private key, wg1 is created without one.
        wg0 = VppWgInterface(self,
                             self.pg1.local_ip4,
                             port,
                             key=X25519PrivateKey.generate()).add_vpp_config()
        wg1 = VppWgInterface(self,
                             self.pg2.local_ip4,
                             port+1).add_vpp_config()
        wg0.admin_up()
        wg1.admin_up()

        # Check peer counter
        self.assertEqual(len(self.vapi.wireguard_peers_dump()), 0)

        # Start capturing before the peer is added so that the packet
        # sent by VPP for the new peer is caught on pg1.
        self.pg_enable_capture(self.pg_interfaces)
        self.pg_start()

        peer_1 = VppWgPeer(self,
                           wg0,
                           self.pg1.remote_ip4,
                           port+1,
                           ["10.11.2.0/24",
                            "10.11.3.0/24"]).add_vpp_config()
        self.assertEqual(len(self.vapi.wireguard_peers_dump()), 1)

        # wait for the peer to send a handshake
        capture = self.pg1.get_capture(1, timeout=2)
        handshake = capture[0]

        # Outer IP/UDP header must go from wg0's address/port to the
        # peer's endpoint/port.
        self.assertEqual(handshake[IP].src, wg0.src)
        self.assertEqual(handshake[IP].dst, peer_1.endpoint)
        self.assertEqual(handshake[UDP].sport, wg0.port)
        self.assertEqual(handshake[UDP].dport, peer_1.port)
        # Re-parse the UDP payload as a WireGuard message.
        handshake = Wireguard(handshake[Raw])
        self.assertEqual(handshake.message_type, 1)  # "initiate"
        # NOTE(review): 'init' is not used afterwards; the lookup itself
        # still requires the WireguardInitiation layer to be present.
        init = handshake[WireguardInitiation]

        # route a packet into the wg interface
        # use the allowed-ip prefix
        p = (Ether(dst=self.pg0.local_mac, src=self.pg0.remote_mac) /
             IP(src=self.pg0.remote_ip4, dst="10.11.3.2") /
             UDP(sport=555, dport=556) /
             Raw())
        # rx = self.send_and_expect(self.pg0, [p], self.pg1)
        # No packet is expected back; presumably because no handshake has
        # completed for the peer yet - TODO confirm.
        rx = self.send_and_assert_no_replies(self.pg0, [p])

        self.logger.info(self.vapi.cli("sh error"))
        # The packet above must have bumped the "Keypair error" counter of
        # the wg-output-tun node exactly once.
        init_sent = wg_output_node_name + "Keypair error"
        self.assertEqual(1, self.statistics.get_err_counter(init_sent))

        # Create many peers on second interface
        NUM_PEERS = 16
        self.pg2.generate_remote_hosts(NUM_PEERS)
        self.pg2.configure_ipv4_neighbors()

        peers = []
        for i in range(NUM_PEERS):
            peers.append(VppWgPeer(self,
                                   wg1,
                                   self.pg2.remote_hosts[i].ip4,
                                   port+1+i,
                                   ["10.10.%d.4/32" % i]).add_vpp_config())
            # peer_1 plus the i+1 peers added so far on wg1.
            self.assertEqual(len(self.vapi.wireguard_peers_dump()), i+2)

        # Debug output only; nothing below is asserted on.
        # NOTE(review): adjacency index 37 and 172.16.3.17 are hard-coded.
        self.logger.info(self.vapi.cli("show wireguard peer"))
        self.logger.info(self.vapi.cli("show wireguard interface"))
        self.logger.info(self.vapi.cli("show adj 37"))
        self.logger.info(self.vapi.cli("sh ip fib 172.16.3.17"))
        self.logger.info(self.vapi.cli("sh ip fib 10.11.3.0"))

        # remove peers
        # ('p' is reused here as the loop variable; the packet built above
        # is no longer needed.)
        for p in peers:
            self.assertTrue(p.query_vpp_config())
            p.remove_vpp_config()
        # Only peer_1 on wg0 is left.
        self.assertEqual(len(self.vapi.wireguard_peers_dump()), 1)
        peer_1.remove_vpp_config()
        self.assertEqual(len(self.vapi.wireguard_peers_dump()), 0)

        wg0.remove_vpp_config()
        # NOTE(review): wg1 is not removed explicitly here.
        # wg1.remove_vpp_config()
|