diff --git a/src/plugins/wireguard/wireguard_output_tun.c b/src/plugins/wireguard/wireguard_output_tun.c index d1b1d6bb8f0..f613d6c0c16 100644 --- a/src/plugins/wireguard/wireguard_output_tun.c +++ b/src/plugins/wireguard/wireguard_output_tun.c @@ -307,6 +307,22 @@ error: return ret; } +static_always_inline void +wg_calc_checksum (vlib_main_t *vm, vlib_buffer_t *b) +{ + int bogus = 0; + u8 ip_ver_out = (*((u8 *) vlib_buffer_get_current (b)) >> 4); + + /* IPv6 UDP checksum is mandatory */ + if (ip_ver_out == 6) + { + ip6_header_t *ip6 = + (ip6_header_t *) ((u8 *) vlib_buffer_get_current (b)); + udp_header_t *udp = ip6_next_header (ip6); + udp->checksum = ip6_tcp_udp_icmp_compute_checksum (vm, b, ip6, &bogus); + } +} + /* is_ip4 - inner header flag */ always_inline uword wg_output_tun_inline (vlib_main_t *vm, vlib_node_runtime_t *node, @@ -555,6 +571,14 @@ wg_output_tun_inline (vlib_main_t *vm, vlib_node_runtime_t *node, /* wg-output-process-ops */ wg_output_process_ops (vm, node, ptd->crypto_ops, sync_bufs, nexts, drop_next); + + int n_left_from_sync_bufs = n_sync; + while (n_left_from_sync_bufs > 0) + { + n_left_from_sync_bufs--; + wg_calc_checksum (vm, sync_bufs[n_left_from_sync_bufs]); + } + vlib_buffer_enqueue_to_next (vm, node, sync_bi, nexts, n_sync); } if (n_async) @@ -627,6 +651,11 @@ wg_output_tun_post (vlib_main_t *vm, vlib_node_runtime_t *node, next[2] = (wg_post_data (b[2]))->next_index; next[3] = (wg_post_data (b[3]))->next_index; + wg_calc_checksum (vm, b[0]); + wg_calc_checksum (vm, b[1]); + wg_calc_checksum (vm, b[2]); + wg_calc_checksum (vm, b[3]); + if (PREDICT_FALSE (node->flags & VLIB_NODE_FLAG_TRACE)) { if (b[0]->flags & VLIB_BUFFER_IS_TRACED) @@ -671,6 +700,8 @@ wg_output_tun_post (vlib_main_t *vm, vlib_node_runtime_t *node, while (n_left > 0) { + wg_calc_checksum (vm, b[0]); + next[0] = (wg_post_data (b[0]))->next_index; if (PREDICT_FALSE ((node->flags & VLIB_NODE_FLAG_TRACE) && (b[0]->flags & VLIB_BUFFER_IS_TRACED))) diff --git a/src/plugins/wireguard/wireguard_send.c b/src/plugins/wireguard/wireguard_send.c index adfa5cac3de..72fa11034bf 100644 --- a/src/plugins/wireguard/wireguard_send.c +++ b/src/plugins/wireguard/wireguard_send.c @@ -41,7 +41,8 @@ ip46_enqueue_packet (vlib_main_t *vm, u32 bi0, int is_ip4) } static void -wg_buffer_prepend_rewrite (vlib_buffer_t *b0, const u8 *rewrite, u8 is_ip4) +wg_buffer_prepend_rewrite (vlib_main_t *vm, vlib_buffer_t *b0, + const u8 *rewrite, u8 is_ip4) { if (is_ip4) { @@ -72,6 +73,13 @@ wg_buffer_prepend_rewrite (vlib_buffer_t *b0, const u8 *rewrite, u8 is_ip4) hdr6->ip6.payload_length = hdr6->udp.length = clib_host_to_net_u16 (b0->current_length - sizeof (ip6_header_t)); + + /* IPv6 UDP checksum is mandatory */ + int bogus = 0; + ip6_header_t *ip6_0 = &(hdr6->ip6); + hdr6->udp.checksum = + ip6_tcp_udp_icmp_compute_checksum (vm, b0, ip6_0, &bogus); + ASSERT (bogus == 0); } } @@ -93,7 +101,7 @@ wg_create_buffer (vlib_main_t *vm, const u8 *rewrite, const u8 *packet, b0->current_length = packet_len; - wg_buffer_prepend_rewrite (b0, rewrite, is_ip4); + wg_buffer_prepend_rewrite (vm, b0, rewrite, is_ip4); return true; } diff --git a/test/test_wireguard.py b/test/test_wireguard.py index b12330ac5bb..80ebdd89aa6 100644 --- a/test/test_wireguard.py +++ b/test/test_wireguard.py @@ -375,12 +375,13 @@ class VppWgPeer(VppObject): if is_ip6 is False: self._test.assertEqual(p[IP].src, self.itf.src) self._test.assertEqual(p[IP].dst, self.endpoint) + self._test.assert_packet_checksums_valid(p) else: self._test.assertEqual(p[IPv6].src, self.itf.src) self._test.assertEqual(p[IPv6].dst, self.endpoint) + self._test.assert_packet_checksums_valid(p, False) self._test.assertEqual(p[UDP].sport, self.itf.port) self._test.assertEqual(p[UDP].dport, self.port) - self._test.assert_packet_checksums_valid(p) def consume_init(self, p, tx_itf, is_ip6=False, is_mac2=False): self.noise.set_as_responder() @@ -466,17 +467,16 @@ class VppWgPeer(VppObject): def encrypt_transport(self, p): return self.noise.encrypt(bytes(p)) - def validate_encapped(self, rxs, tx, is_ip6=False): + def validate_encapped(self, rxs, tx, is_tunnel_ip6=False, is_transport_ip6=False): for rx in rxs: - if is_ip6 is False: - rx = IP(self.decrypt_transport(rx, is_ip6=is_ip6)) - + rx = self.decrypt_transport(rx, is_tunnel_ip6) + if is_transport_ip6 is False: + rx = IP(rx) # check the original packet is present self._test.assertEqual(rx[IP].dst, tx[IP].dst) self._test.assertEqual(rx[IP].ttl, tx[IP].ttl - 1) else: - rx = IPv6(self.decrypt_transport(rx, is_ip6=is_ip6)) - + rx = IPv6(rx) # check the original packet is present self._test.assertEqual(rx[IPv6].dst, tx[IPv6].dst) self._test.assertEqual(rx[IPv6].hlim, tx[IPv6].hlim - 1) @@ -1222,7 +1222,7 @@ class TestWg(VppTestCase): rxs = self.send_and_expect(self.pg0, [p], self.pg1) # verify the data packet - peer_1.validate_encapped(rxs, p, is_ip6=is_ip6) + peer_1.validate_encapped(rxs, p, is_tunnel_ip6=is_ip6, is_transport_ip6=is_ip6) # remove configs r1.remove_vpp_config() @@ -1246,7 +1246,7 @@ class TestWg(VppTestCase): self._test_wg_peer_roaming_on_data_tmpl(is_async=True, is_ip6=True) def test_wg_peer_resp(self): - """Send handshake response""" + """Send handshake response IPv4 tunnel""" port = 12323 # Create interfaces @@ -1323,6 +1323,83 @@ class TestWg(VppTestCase): peer_1.remove_vpp_config() wg0.remove_vpp_config() + def test_wg_peer_resp_ipv6(self): + """Send handshake response IPv6 tunnel""" + port = 12323 + + # Create interfaces + wg0 = VppWgInterface(self, self.pg1.local_ip6, port).add_vpp_config() + wg0.admin_up() + wg0.config_ip4() + + self.pg_enable_capture(self.pg_interfaces) + self.pg_start() + + peer_1 = VppWgPeer( + self, wg0, self.pg1.remote_ip6, port + 1, ["10.11.3.0/24"] + ).add_vpp_config() + self.assertEqual(len(self.vapi.wireguard_peers_dump()), 1) + + r1 = VppIpRoute( + self, "10.11.3.0", 24, [VppRoutePath("10.11.3.1", wg0.sw_if_index)] + ).add_vpp_config() + + # wait for the peer to send a handshake + rx = self.pg1.get_capture(1, timeout=2) + + # consume the handshake in the noise protocol and + # generate the response + resp = peer_1.consume_init(rx[0], self.pg1, is_ip6=True) + + # send the response, get keepalive + rxs = self.send_and_expect(self.pg1, [resp], self.pg1) + + for rx in rxs: + b = peer_1.decrypt_transport(rx, True) + self.assertEqual(0, len(b)) + + # send a packets that are routed into the tunnel + p = ( + Ether(dst=self.pg0.local_mac, src=self.pg0.remote_mac) + / IP(src=self.pg0.remote_ip4, dst="10.11.3.2") + / UDP(sport=555, dport=556) + / Raw(b"\x00" * 80) + ) + + rxs = self.send_and_expect(self.pg0, p * 2, self.pg1) + peer_1.validate_encapped(rxs, p, True) + + # send packets into the tunnel, expect to receive them on + # the other side + p = [ + ( + peer_1.mk_tunnel_header(self.pg1, True) + / Wireguard(message_type=4, reserved_zero=0) + / WireguardTransport( + receiver_index=peer_1.sender, + counter=ii, + encrypted_encapsulated_packet=peer_1.encrypt_transport( + ( + IP(src="10.11.3.1", dst=self.pg0.remote_ip4, ttl=20) + / UDP(sport=222, dport=223) + / Raw() + ) + ), + ) + ) + for ii in range(255) + ] + + rxs = self.send_and_expect(self.pg1, p, self.pg0) + + for rx in rxs: + self.assertEqual(rx[IP].dst, self.pg0.remote_ip4) + self.assertEqual(rx[IP].ttl, 19) + + r1.remove_vpp_config() + peer_1.remove_vpp_config() + wg0.remove_vpp_config() + def test_wg_peer_v4o4(self): """Test v4o4"""