diff --git a/src/plugins/wireguard/test/test_wireguard.py b/src/plugins/wireguard/test/test_wireguard.py index cd124f3e246..77349396ec6 100755 --- a/src/plugins/wireguard/test/test_wireguard.py +++ b/src/plugins/wireguard/test/test_wireguard.py @@ -1,15 +1,24 @@ #!/usr/bin/env python3 """ Wg tests """ +import datetime +import base64 + +from hashlib import blake2s from scapy.packet import Packet from scapy.packet import Raw -from scapy.layers.l2 import Ether +from scapy.layers.l2 import Ether, ARP from scapy.layers.inet import IP, UDP from scapy.contrib.wireguard import Wireguard, WireguardResponse, \ - WireguardInitiation -from cryptography.hazmat.primitives.asymmetric.x25519 import X25519PrivateKey + WireguardInitiation, WireguardTransport +from cryptography.hazmat.primitives.asymmetric.x25519 import \ + X25519PrivateKey, X25519PublicKey from cryptography.hazmat.primitives.serialization import Encoding, \ PrivateFormat, PublicFormat, NoEncryption +from cryptography.hazmat.primitives.hashes import BLAKE2s, Hash +from cryptography.hazmat.primitives.hmac import HMAC +from cryptography.hazmat.backends import default_backend +from noise.connection import NoiseConnection, Keypair from vpp_ipip_tun_interface import VppIpIpTunInterface from vpp_interface import VppInterface @@ -25,41 +34,48 @@ Wg test. """ +def private_key_bytes(k): + return k.private_bytes(Encoding.Raw, + PrivateFormat.Raw, + NoEncryption()) + + +def public_key_bytes(k): + return k.public_bytes(Encoding.Raw, + PublicFormat.Raw) + + class VppWgInterface(VppInterface): """ VPP WireGuard interface """ - def __init__(self, test, src, port, key=None): + def __init__(self, test, src, port): super(VppWgInterface, self).__init__(test) - self.key = key - if not self.key: - self.generate = True - else: - self.generate = False self.port = port self.src = src + self.private_key = X25519PrivateKey.generate() + self.public_key = self.private_key.public_key() + + def public_key_bytes(self): + return public_key_bytes(self.public_key) + + def private_key_bytes(self): + return private_key_bytes(self.private_key) def add_vpp_config(self): r = self.test.vapi.wireguard_interface_create(interface={ 'user_instance': 0xffffffff, 'port': self.port, 'src_ip': self.src, - 'private_key': self.key_bytes() + 'private_key': private_key_bytes(self.private_key), + 'generate_key': False }) self.set_sw_if_index(r.sw_if_index) self.test.registry.register(self, self.test.logger) return self - def key_bytes(self): - if self.key: - return self.key.private_bytes(Encoding.Raw, - PrivateFormat.Raw, - NoEncryption()) - else: - return bytearray(32) - def remove_vpp_config(self): self.test.vapi.wireguard_interface_delete( sw_if_index=self._sw_if_index) @@ -70,7 +86,7 @@ class VppWgInterface(VppInterface): if t.interface.sw_if_index == self._sw_if_index and \ str(t.interface.src_ip) == self.src and \ t.interface.port == self.port and \ - t.interface.private_key == self.key_bytes(): + t.interface.private_key == private_key_bytes(self.private_key): return True return False @@ -91,6 +107,10 @@ def find_route(test, prefix, table_id=0): return False +NOISE_HANDSHAKE_NAME = b"Noise_IKpsk2_25519_ChaChaPoly_BLAKE2s" +NOISE_IDENTIFIER_NAME = b"WireGuard v1 zx2c4 Jason@zx2c4.com" + + class VppWgPeer(VppObject): def __init__(self, @@ -106,9 +126,12 @@ class VppWgPeer(VppObject): self.port = port self.allowed_ips = allowed_ips self.persistent_keepalive = persistent_keepalive + + # remote peer's public self.private_key = X25519PrivateKey.generate() self.public_key = self.private_key.public_key() - self.hash = bytearray(16) + + self.noise = NoiseConnection.from_name(NOISE_HANDSHAKE_NAME) def validate_routing(self): for a in self.allowed_ips: @@ -129,6 +152,7 @@ class VppWgPeer(VppObject): 'sw_if_index': self.itf.sw_if_index, 'persistent_keepalive': self.persistent_keepalive}) self.index = rv.peer_index + self.receiver_index = self.index + 1 self._test.registry.register(self, self._test.logger) self.validate_routing() return self @@ -141,13 +165,7 @@ class VppWgPeer(VppObject): return ("wireguard-peer-%s" % self.index) def public_key_bytes(self): - return self.public_key.public_bytes(Encoding.Raw, - PublicFormat.Raw) - - def private_key_bytes(self): - return self.private_key.private_bytes(Encoding.Raw, - PrivateFormat.Raw, - NoEncryption()) + return public_key_bytes(self.public_key) def query_vpp_config(self): peers = self._test.vapi.wireguard_peers_dump() @@ -167,6 +185,148 @@ class VppWgPeer(VppObject): return True return False + def set_responder(self): + self.noise.set_as_responder() + + def mk_tunnel_header(self, tx_itf): + return (Ether(dst=tx_itf.local_mac, src=tx_itf.remote_mac) / + IP(src=self.endpoint, dst=self.itf.src) / + UDP(sport=self.port, dport=self.itf.port)) + + def noise_init(self, public_key=None): + self.noise.set_prologue(NOISE_IDENTIFIER_NAME) + self.noise.set_psks(psk=bytes(bytearray(32))) + + if not public_key: + public_key = self.itf.public_key + + # local/this private + self.noise.set_keypair_from_private_bytes( + Keypair.STATIC, + private_key_bytes(self.private_key)) + # remote's public + self.noise.set_keypair_from_public_bytes( + Keypair.REMOTE_STATIC, + public_key_bytes(public_key)) + + self.noise.start_handshake() + + def mk_handshake(self, tx_itf, public_key=None): + self.noise.set_as_initiator() + self.noise_init(public_key) + + p = (Wireguard() / WireguardInitiation()) + + p[Wireguard].message_type = 1 + p[Wireguard].reserved_zero = 0 + p[WireguardInitiation].sender_index = self.receiver_index + + # some random data for the message + # lifted from the noise protocol's wireguard example + now = datetime.datetime.now() + tai = struct.pack('!qi', 4611686018427387914 + int(now.timestamp()), + int(now.microsecond * 1e3)) + b = self.noise.write_message(payload=tai) + + # load noise into init message + p[WireguardInitiation].unencrypted_ephemeral = b[0:32] + p[WireguardInitiation].encrypted_static = b[32:80] + p[WireguardInitiation].encrypted_timestamp = b[80:108] + + # generate the mac1 hash + mac_key = blake2s(b'mac1----' + + self.itf.public_key_bytes()).digest() + p[WireguardInitiation].mac1 = blake2s(bytes(p)[0:116], + digest_size=16, + key=mac_key).digest() + p[WireguardInitiation].mac2 = bytearray(16) + + p = (self.mk_tunnel_header(tx_itf) / p) + + return p + + def verify_header(self, p): + self._test.assertEqual(p[IP].src, self.itf.src) + self._test.assertEqual(p[IP].dst, self.endpoint) + self._test.assertEqual(p[UDP].sport, self.itf.port) + self._test.assertEqual(p[UDP].dport, self.port) + self._test.assert_packet_checksums_valid(p) + + def consume_init(self, p, tx_itf): + self.noise.set_as_responder() + self.noise_init(self.itf.public_key) + self.verify_header(p) + + init = Wireguard(p[Raw]) + + self._test.assertEqual(init[Wireguard].message_type, 1) + self._test.assertEqual(init[Wireguard].reserved_zero, 0) + + self.sender = init[WireguardInitiation].sender_index + + # validate the hash + mac_key = blake2s(b'mac1----' + + public_key_bytes(self.public_key)).digest() + mac1 = blake2s(bytes(init)[0:-32], + digest_size=16, + key=mac_key).digest() + self._test.assertEqual(init[WireguardInitiation].mac1, mac1) + + # this passes only unencrypted_ephemeral, encrypted_static, + # encrypted_timestamp fields of the init + payload = self.noise.read_message(bytes(init)[8:-32]) + + # build the response + b = self.noise.write_message() + mac_key = blake2s(b'mac1----' + + public_key_bytes(self.itf.public_key)).digest() + resp = (Wireguard(message_type=2, reserved_zero=0) / + WireguardResponse(sender_index=self.receiver_index, + receiver_index=self.sender, + unencrypted_ephemeral=b[0:32], + encrypted_nothing=b[32:])) + mac1 = blake2s(bytes(resp)[:-32], + digest_size=16, + key=mac_key).digest() + resp[WireguardResponse].mac1 = mac1 + + resp = (self.mk_tunnel_header(tx_itf) / resp) + self._test.assertTrue(self.noise.handshake_finished) + + return resp + + def consume_response(self, p): + self.verify_header(p) + + resp = Wireguard(p[Raw]) + + self._test.assertEqual(resp[Wireguard].message_type, 2) + self._test.assertEqual(resp[Wireguard].reserved_zero, 0) + self._test.assertEqual(resp[WireguardResponse].receiver_index, + self.receiver_index) + + self.sender = resp[Wireguard].sender_index + + payload = self.noise.read_message(bytes(resp)[12:60]) + self._test.assertEqual(payload, b'') + self._test.assertTrue(self.noise.handshake_finished) + + def decrypt_transport(self, p): + self.verify_header(p) + + p = Wireguard(p[Raw]) + self._test.assertEqual(p[Wireguard].message_type, 4) + self._test.assertEqual(p[Wireguard].reserved_zero, 0) + self._test.assertEqual(p[WireguardTransport].receiver_index, + self.receiver_index) + + d = self.noise.decrypt( + p[WireguardTransport].encrypted_encapsulated_packet) + return d + + def encrypt_transport(self, p): + return self.noise.encrypt(bytes(p)) + class TestWg(VppTestCase): """ Wireguard Test Case """ @@ -192,6 +352,7 @@ class TestWg(VppTestCase): super(TestWg, cls).tearDownClass() def test_wg_interface(self): + """ Simple interface creation """ port = 12312 # Create interface @@ -204,7 +365,51 @@ class TestWg(VppTestCase): # delete interface wg0.remove_vpp_config() - def test_wg_peer(self): + def test_handshake_hash(self): + """ test hashing an init message """ + # a init packet generated by linux given the key below + h = "0100000098b9032b" \ + "55cc4b39e73c3d24" \ + "a2a1ab884b524a81" \ + "1808bb86640fb70d" \ + "e93154fec1879125" \ + "ab012624a27f0b75" \ + "c0a2582f438ddb5f" \ + "8e768af40b4ab444" \ + "02f9ff473e1b797e" \ + "80d39d93c5480c82" \ + "a3d4510f70396976" \ + "586fb67300a5167b" \ + "ae6ca3ff3dfd00eb" \ + "59be198810f5aa03" \ + "6abc243d2155ee4f" \ + "2336483900aef801" \ + "08752cd700000000" \ + "0000000000000000" \ + "00000000" + + b = bytearray.fromhex(h) + tgt = Wireguard(b) + + pubb = base64.b64decode("aRuHFTTxICIQNefp05oKWlJv3zgKxb8+WW7JJMh0jyM=") + pub = X25519PublicKey.from_public_bytes(pubb) + + self.assertEqual(pubb, public_key_bytes(pub)) + + # strip the macs and build a new packet + init = b[0:-32] + mac_key = blake2s(b'mac1----' + public_key_bytes(pub)).digest() + init += blake2s(init, + digest_size=16, + key=mac_key).digest() + init += b'\x00' * 16 + + act = Wireguard(init) + + self.assertEqual(tgt, act) + + def test_wg_peer_resp(self): + """ Send handshake response """ wg_output_node_name = '/err/wg-output-tun/' wg_input_node_name = '/err/wg-input/' @@ -213,16 +418,9 @@ class TestWg(VppTestCase): # Create interfaces wg0 = VppWgInterface(self, self.pg1.local_ip4, - port, - key=X25519PrivateKey.generate()).add_vpp_config() - wg1 = VppWgInterface(self, - self.pg2.local_ip4, - port+1).add_vpp_config() + port).add_vpp_config() wg0.admin_up() - wg1.admin_up() - - # Check peer counter - self.assertEqual(len(self.vapi.wireguard_peers_dump()), 0) + wg0.config_ip4() self.pg_enable_capture(self.pg_interfaces) self.pg_start() @@ -236,43 +434,210 @@ class TestWg(VppTestCase): self.assertEqual(len(self.vapi.wireguard_peers_dump()), 1) # wait for the peer to send a handshake - capture = self.pg1.get_capture(1, timeout=2) - handshake = capture[0] + rx = self.pg1.get_capture(1, timeout=2) - self.assertEqual(handshake[IP].src, wg0.src) - self.assertEqual(handshake[IP].dst, peer_1.endpoint) - self.assertEqual(handshake[UDP].sport, wg0.port) - self.assertEqual(handshake[UDP].dport, peer_1.port) - handshake = Wireguard(handshake[Raw]) - self.assertEqual(handshake.message_type, 1) # "initiate") - init = handshake[WireguardInitiation] + # consume the handshake in the noise protocol and + # generate the response + resp = peer_1.consume_init(rx[0], self.pg1) + + # send the response, get keepalive + rxs = self.send_and_expect(self.pg1, [resp], self.pg1) + + for rx in rxs: + b = peer_1.decrypt_transport(rx) + self.assertEqual(0, len(b)) + + # send a packets that are routed into the tunnel + p = (Ether(dst=self.pg0.local_mac, src=self.pg0.remote_mac) / + IP(src=self.pg0.remote_ip4, dst="10.11.3.2") / + UDP(sport=555, dport=556) / + Raw(b'\x00' * 80)) + + rxs = self.send_and_expect(self.pg0, p * 255, self.pg1) + + for rx in rxs: + rx = IP(peer_1.decrypt_transport(rx)) + # chech the oringial packet is present + self.assertEqual(rx[IP].dst, p[IP].dst) + self.assertEqual(rx[IP].ttl, p[IP].ttl-1) + + # send packets into the tunnel, expect to receive them on + # the other side + p = [(peer_1.mk_tunnel_header(self.pg1) / + Wireguard(message_type=4, reserved_zero=0) / + WireguardTransport( + receiver_index=peer_1.sender, + counter=ii, + encrypted_encapsulated_packet=peer_1.encrypt_transport( + (IP(src="10.11.3.1", dst=self.pg0.remote_ip4, ttl=20) / + UDP(sport=222, dport=223) / + Raw())))) for ii in range(255)] + + rxs = self.send_and_expect(self.pg1, p, self.pg0) + + for rx in rxs: + self.assertEqual(rx[IP].dst, self.pg0.remote_ip4) + self.assertEqual(rx[IP].ttl, 19) + + def test_wg_peer_init(self): + """ Send handshake init """ + wg_output_node_name = '/err/wg-output-tun/' + wg_input_node_name = '/err/wg-input/' + + port = 12323 + + # Create interfaces + wg0 = VppWgInterface(self, + self.pg1.local_ip4, + port).add_vpp_config() + wg0.admin_up() + wg0.config_ip4() + + peer_1 = VppWgPeer(self, + wg0, + self.pg1.remote_ip4, + port+1, + ["10.11.2.0/24", + "10.11.3.0/24"]).add_vpp_config() + self.assertEqual(len(self.vapi.wireguard_peers_dump()), 1) # route a packet into the wg interface # use the allowed-ip prefix + # this is dropped because the peer is not initiated p = (Ether(dst=self.pg0.local_mac, src=self.pg0.remote_mac) / IP(src=self.pg0.remote_ip4, dst="10.11.3.2") / UDP(sport=555, dport=556) / Raw()) - # rx = self.send_and_expect(self.pg0, [p], self.pg1) - rx = self.send_and_assert_no_replies(self.pg0, [p]) + self.send_and_assert_no_replies(self.pg0, [p]) - self.logger.info(self.vapi.cli("sh error")) - init_sent = wg_output_node_name + "Keypair error" - self.assertEqual(1, self.statistics.get_err_counter(init_sent)) + kp_error = wg_output_node_name + "Keypair error" + self.assertEqual(1, self.statistics.get_err_counter(kp_error)) + + # send a handsake from the peer with an invalid MAC + p = peer_1.mk_handshake(self.pg1) + p[WireguardInitiation].mac1 = b'foobar' + self.send_and_assert_no_replies(self.pg1, [p]) + self.assertEqual(1, self.statistics.get_err_counter( + wg_input_node_name + "Invalid MAC handshake")) + + # send a handsake from the peer but signed by the wrong key. + p = peer_1.mk_handshake(self.pg1, + X25519PrivateKey.generate().public_key()) + self.send_and_assert_no_replies(self.pg1, [p]) + self.assertEqual(1, self.statistics.get_err_counter( + wg_input_node_name + "Peer error")) + + # send a valid handsake init for which we expect a response + p = peer_1.mk_handshake(self.pg1) + + rx = self.send_and_expect(self.pg1, [p], self.pg1) + + peer_1.consume_response(rx[0]) + + # route a packet into the wg interface + # this is dropped because the peer is still not initiated + p = (Ether(dst=self.pg0.local_mac, src=self.pg0.remote_mac) / + IP(src=self.pg0.remote_ip4, dst="10.11.3.2") / + UDP(sport=555, dport=556) / + Raw()) + self.send_and_assert_no_replies(self.pg0, [p]) + self.assertEqual(2, self.statistics.get_err_counter(kp_error)) + + # send a data packet from the peer through the tunnel + # this completes the handshake + p = (IP(src="10.11.3.1", dst=self.pg0.remote_ip4, ttl=20) / + UDP(sport=222, dport=223) / + Raw()) + d = peer_1.encrypt_transport(p) + p = (peer_1.mk_tunnel_header(self.pg1) / + (Wireguard(message_type=4, reserved_zero=0) / + WireguardTransport(receiver_index=peer_1.sender, + counter=0, + encrypted_encapsulated_packet=d))) + rxs = self.send_and_expect(self.pg1, [p], self.pg0) + + for rx in rxs: + self.assertEqual(rx[IP].dst, self.pg0.remote_ip4) + self.assertEqual(rx[IP].ttl, 19) + + # send a packets that are routed into the tunnel + p = (Ether(dst=self.pg0.local_mac, src=self.pg0.remote_mac) / + IP(src=self.pg0.remote_ip4, dst="10.11.3.2") / + UDP(sport=555, dport=556) / + Raw(b'\x00' * 80)) + + rxs = self.send_and_expect(self.pg0, p * 255, self.pg1) + + for rx in rxs: + rx = IP(peer_1.decrypt_transport(rx)) + + # chech the oringial packet is present + self.assertEqual(rx[IP].dst, p[IP].dst) + self.assertEqual(rx[IP].ttl, p[IP].ttl-1) + + # send packets into the tunnel, expect to receive them on + # the other side + p = [(peer_1.mk_tunnel_header(self.pg1) / + Wireguard(message_type=4, reserved_zero=0) / + WireguardTransport( + receiver_index=peer_1.sender, + counter=ii+1, + encrypted_encapsulated_packet=peer_1.encrypt_transport( + (IP(src="10.11.3.1", dst=self.pg0.remote_ip4, ttl=20) / + UDP(sport=222, dport=223) / + Raw())))) for ii in range(255)] + + rxs = self.send_and_expect(self.pg1, p, self.pg0) + + for rx in rxs: + self.assertEqual(rx[IP].dst, self.pg0.remote_ip4) + self.assertEqual(rx[IP].ttl, 19) + + peer_1.remove_vpp_config() + wg0.remove_vpp_config() + + def test_wg_multi_peer(self): + """ multiple peer setup """ + port = 12323 + + # Create interfaces + wg0 = VppWgInterface(self, + self.pg1.local_ip4, + port).add_vpp_config() + wg1 = VppWgInterface(self, + self.pg2.local_ip4, + port+1).add_vpp_config() + wg0.admin_up() + wg1.admin_up() + + # Check peer counter + self.assertEqual(len(self.vapi.wireguard_peers_dump()), 0) + + self.pg_enable_capture(self.pg_interfaces) + self.pg_start() # Create many peers on sencond interface NUM_PEERS = 16 self.pg2.generate_remote_hosts(NUM_PEERS) self.pg2.configure_ipv4_neighbors() + self.pg1.generate_remote_hosts(NUM_PEERS) + self.pg1.configure_ipv4_neighbors() - peers = [] + peers_1 = [] + peers_2 = [] for i in range(NUM_PEERS): - peers.append(VppWgPeer(self, - wg1, - self.pg2.remote_hosts[i].ip4, - port+1+i, - ["10.10.%d.4/32" % i]).add_vpp_config()) - self.assertEqual(len(self.vapi.wireguard_peers_dump()), i+2) + peers_1.append(VppWgPeer(self, + wg0, + self.pg1.remote_hosts[i].ip4, + port+1+i, + ["10.0.%d.4/32" % i]).add_vpp_config()) + peers_2.append(VppWgPeer(self, + wg1, + self.pg2.remote_hosts[i].ip4, + port+100+i, + ["10.100.%d.4/32" % i]).add_vpp_config()) + + self.assertEqual(len(self.vapi.wireguard_peers_dump()), NUM_PEERS*2) self.logger.info(self.vapi.cli("show wireguard peer")) self.logger.info(self.vapi.cli("show wireguard interface")) @@ -281,12 +646,12 @@ class TestWg(VppTestCase): self.logger.info(self.vapi.cli("sh ip fib 10.11.3.0")) # remove peers - for p in peers: + for p in peers_1: + self.assertTrue(p.query_vpp_config()) + p.remove_vpp_config() + for p in peers_2: self.assertTrue(p.query_vpp_config()) p.remove_vpp_config() - self.assertEqual(len(self.vapi.wireguard_peers_dump()), 1) - peer_1.remove_vpp_config() - self.assertEqual(len(self.vapi.wireguard_peers_dump()), 0) wg0.remove_vpp_config() - # wg1.remove_vpp_config() + wg1.remove_vpp_config() diff --git a/src/plugins/wireguard/wireguard_cookie.c b/src/plugins/wireguard/wireguard_cookie.c index aa476f792a1..f54ce715906 100755 --- a/src/plugins/wireguard/wireguard_cookie.c +++ b/src/plugins/wireguard/wireguard_cookie.c @@ -86,7 +86,7 @@ cookie_checker_validate_macs (vlib_main_t * vm, cookie_checker_t * cc, len = len - sizeof (message_macs_t); cookie_macs_mac1 (&our_cm, buf, len, cc->cc_mac1_key); - /* If mac1 is invald, we want to drop the packet */ + /* If mac1 is invalid, we want to drop the packet */ if (clib_memcmp (our_cm.mac1, cm->mac1, COOKIE_MAC_SIZE) != 0) return INVALID_MAC; diff --git a/src/plugins/wireguard/wireguard_if.c b/src/plugins/wireguard/wireguard_if.c index ff8ed35477e..ed90146a800 100644 --- a/src/plugins/wireguard/wireguard_if.c +++ b/src/plugins/wireguard/wireguard_if.c @@ -42,11 +42,21 @@ format_wg_if (u8 * s, va_list * args) key_to_base64 (wgi->local.l_private, NOISE_PUBLIC_KEY_LEN, key); s = format (s, " private-key:%s", key); + s = + format (s, " %U", format_hex_bytes, wgi->local.l_private, + NOISE_PUBLIC_KEY_LEN); key_to_base64 (wgi->local.l_public, NOISE_PUBLIC_KEY_LEN, key); s = format (s, " public-key:%s", key); + s = + format (s, " %U", format_hex_bytes, wgi->local.l_public, + NOISE_PUBLIC_KEY_LEN); + + s = format (s, " mac-key: %U", format_hex_bytes, + &wgi->cookie_checker.cc_mac1_key, NOISE_PUBLIC_KEY_LEN); + return (s); } @@ -235,9 +245,6 @@ wg_if_create (u32 user_instance, if (~0 == wg_if->user_instance) wg_if->user_instance = t_idx; - udp_dst_port_info_t *pi = udp_get_dst_port_info (&udp_main, port, UDP_IP4); - if (pi) - return (VNET_API_ERROR_VALUE_EXIST); udp_register_dst_port (vlib_get_main (), port, wg_input_node.index, 1); vec_validate_init_empty (wg_if_index_by_port, port, INDEX_INVALID); @@ -280,16 +287,17 @@ wg_if_delete (u32 sw_if_index) vnet_hw_interface_t *hw = vnet_get_sup_hw_interface (vnm, sw_if_index); if (hw == 0 || hw->dev_class_index != wg_if_device_class.index) - return VNET_API_ERROR_INVALID_SW_IF_INDEX; + return VNET_API_ERROR_INVALID_VALUE; wg_if_t *wg_if; wg_if = wg_if_get (wg_if_find_by_sw_if_index (sw_if_index)); if (NULL == wg_if) - return VNET_API_ERROR_INVALID_SW_IF_INDEX; + return VNET_API_ERROR_INVALID_SW_IF_INDEX_2; - if (wg_if_instance_free (hw->dev_instance) < 0) - return VNET_API_ERROR_INVALID_SW_IF_INDEX; + if (wg_if_instance_free (wg_if->user_instance) < 0) + return VNET_API_ERROR_INVALID_VALUE_2; + udp_unregister_dst_port (vlib_get_main (), wg_if->port, 1); wg_if_index_by_port[wg_if->port] = INDEX_INVALID; vnet_delete_hw_interface (vnm, hw->hw_if_index); pool_put (wg_if_pool, wg_if); diff --git a/src/plugins/wireguard/wireguard_input.c b/src/plugins/wireguard/wireguard_input.c index 832ad031daa..cdd65f87b51 100755 --- a/src/plugins/wireguard/wireguard_input.c +++ b/src/plugins/wireguard/wireguard_input.c @@ -313,12 +313,12 @@ VLIB_NODE_FN (wg_input_node) (vlib_main_t * vm, if (entry) { peer = pool_elt_at_index (wmp->peers, *entry); - if (!peer) - { - next[0] = WG_INPUT_NEXT_ERROR; - b[0]->error = node->errors[WG_INPUT_ERROR_PEER]; - goto out; - } + } + else + { + next[0] = WG_INPUT_NEXT_ERROR; + b[0]->error = node->errors[WG_INPUT_ERROR_PEER]; + goto out; } u16 encr_len = b[0]->current_length - sizeof (message_data_t); diff --git a/src/plugins/wireguard/wireguard_noise.c b/src/plugins/wireguard/wireguard_noise.c index dc7d5060fe5..b47bb5747b9 100755 --- a/src/plugins/wireguard/wireguard_noise.c +++ b/src/plugins/wireguard/wireguard_noise.c @@ -536,7 +536,7 @@ noise_remote_ready (noise_remote_t * r) return ret; } -static void +static bool chacha20poly1305_calc (vlib_main_t * vm, u8 * src, u32 src_len, @@ -580,6 +580,8 @@ chacha20poly1305_calc (vlib_main_t * vm, { clib_memcpy (dst + src_len, op->tag, NOISE_AUTHTAG_LEN); } + + return (op->status == VNET_CRYPTO_OP_STATUS_COMPLETED); } enum noise_state_crypt @@ -668,9 +670,10 @@ noise_remote_decrypt (vlib_main_t * vm, noise_remote_t * r, uint32_t r_idx, /* Decrypt, then validate the counter. We don't want to validate the * counter before decrypting as we do not know the message is authentic * prior to decryption. */ - chacha20poly1305_calc (vm, src, srclen, dst, NULL, 0, nonce, - VNET_CRYPTO_OP_CHACHA20_POLY1305_DEC, - kp->kp_recv_index); + if (!chacha20poly1305_calc (vm, src, srclen, dst, NULL, 0, nonce, + VNET_CRYPTO_OP_CHACHA20_POLY1305_DEC, + kp->kp_recv_index)) + goto error; if (!noise_counter_recv (&kp->kp_ctr, nonce)) goto error; @@ -936,8 +939,9 @@ noise_msg_decrypt (vlib_main_t * vm, uint8_t * dst, uint8_t * src, uint8_t hash[NOISE_HASH_LEN]) { /* Nonce always zero for Noise_IK */ - chacha20poly1305_calc (vm, src, src_len, dst, hash, NOISE_HASH_LEN, 0, - VNET_CRYPTO_OP_CHACHA20_POLY1305_DEC, key_idx); + if (!chacha20poly1305_calc (vm, src, src_len, dst, hash, NOISE_HASH_LEN, 0, + VNET_CRYPTO_OP_CHACHA20_POLY1305_DEC, key_idx)) + return false; noise_mix_hash (hash, src, src_len); return true; } diff --git a/src/plugins/wireguard/wireguard_output_tun.c b/src/plugins/wireguard/wireguard_output_tun.c index daec7a4a2f1..cdfd9d730f6 100755 --- a/src/plugins/wireguard/wireguard_output_tun.c +++ b/src/plugins/wireguard/wireguard_output_tun.c @@ -115,7 +115,8 @@ VLIB_NODE_FN (wg_output_tun_node) (vlib_main_t * vm, while (n_left_from > 0) { ip4_udp_header_t *hdr = vlib_buffer_get_current (b[0]); - u8 *plain_data = vlib_buffer_get_current (b[0]) + sizeof (ip4_header_t); + u8 *plain_data = (vlib_buffer_get_current (b[0]) + + sizeof (ip4_udp_header_t)); u16 plain_data_len = clib_net_to_host_u16 (((ip4_header_t *) plain_data)->length); @@ -144,8 +145,8 @@ VLIB_NODE_FN (wg_output_tun_node) (vlib_main_t * vm, * Ensure there is enough space to write the encrypted data * into the packet */ - if (PREDICT_FALSE (encrypted_packet_len > WG_OUTPUT_SCRATCH_SIZE) || - PREDICT_FALSE ((b[0]->current_data + encrypted_packet_len) < + if (PREDICT_FALSE (encrypted_packet_len >= WG_OUTPUT_SCRATCH_SIZE) || + PREDICT_FALSE ((b[0]->current_data + encrypted_packet_len) >= vlib_buffer_get_default_data_size (vm))) { b[0]->error = node->errors[WG_OUTPUT_ERROR_TOO_BIG]; diff --git a/src/plugins/wireguard/wireguard_peer.c b/src/plugins/wireguard/wireguard_peer.c index 0dcc4e20e41..04b07d97652 100755 --- a/src/plugins/wireguard/wireguard_peer.c +++ b/src/plugins/wireguard/wireguard_peer.c @@ -380,15 +380,16 @@ format_wg_peer (u8 * s, va_list * va) peer = wg_peer_get (peeri); key_to_base64 (peer->remote.r_public, NOISE_PUBLIC_KEY_LEN, key); - s = format (s, "[%d] key:%=45s endpoint:[%U->%U] %U keep-alive:%d adj:%d", + s = format (s, "[%d] endpoint:[%U->%U] %U keep-alive:%d adj:%d", peeri, - key, format_wg_peer_endpoint, &peer->src, format_wg_peer_endpoint, &peer->dst, format_vnet_sw_if_index_name, vnet_get_main (), peer->wg_sw_if_index, peer->persistent_keepalive_interval, peer->adj_index); - + s = format (s, "\n key:%=s %U", + key, format_hex_bytes, peer->remote.r_public, + NOISE_PUBLIC_KEY_LEN); s = format (s, "\n allowed-ips:"); vec_foreach (allowed_ip, peer->allowed_ips) {