diff --git a/src/plugins/cnat/cnat_node.h b/src/plugins/cnat/cnat_node.h index d81f6745bc4..549eeae4416 100644 --- a/src/plugins/cnat/cnat_node.h +++ b/src/plugins/cnat/cnat_node.h @@ -334,6 +334,10 @@ cnat_translation_icmp4_error (ip4_header_t *outer_ip4, icmp46_header_t *icmp, cnat_ip4_translate_l4 (ip4, udp, &inner_l4_sum, new_addr, new_port, 0 /* flags */); tcp->checksum = ip_csum_fold (inner_l4_sum); + + /* TCP checksum changed */ + sum = ip_csum_update (sum, inner_l4_old_sum, inner_l4_sum, ip4_header_t, + checksum); } else if (ip4->protocol == IP_PROTOCOL_UDP) { @@ -341,14 +345,40 @@ cnat_translation_icmp4_error (ip4_header_t *outer_ip4, icmp46_header_t *icmp, cnat_ip4_translate_l4 (ip4, udp, &inner_l4_sum, new_addr, new_port, 0 /* flags */); udp->checksum = ip_csum_fold (inner_l4_sum); + + /* UDP checksum changed */ + sum = ip_csum_update (sum, inner_l4_old_sum, inner_l4_sum, ip4_header_t, + checksum); + } + else if (ip4->protocol == IP_PROTOCOL_ICMP) + { + icmp46_header_t *icmp = (icmp46_header_t *) udp; + if (icmp_type_is_echo (icmp->type)) + { + u16 old_port; + cnat_echo_header_t *echo = (cnat_echo_header_t *) (icmp + 1); + inner_l4_old_sum = inner_l4_sum = icmp->checksum; + + old_port = echo->identifier; + echo->identifier = new_port[VLIB_RX]; + + inner_l4_sum = ip_csum_update ( + inner_l4_sum, old_port, new_port[VLIB_RX], udp_header_t, src_port); + + icmp->checksum = ip_csum_fold (inner_l4_sum); + + sum = ip_csum_update (sum, old_port, new_port[VLIB_RX], udp_header_t, + src_port); + /* checksum changed */ + sum = ip_csum_update (sum, inner_l4_old_sum, inner_l4_sum, + ip4_header_t, checksum); + } + old_port[VLIB_TX] = 0; + old_port[VLIB_RX] = 0; } else return; - /* UDP/TCP checksum changed */ - sum = ip_csum_update (sum, inner_l4_old_sum, inner_l4_sum, - ip4_header_t, checksum); - /* UDP/TCP Ports changed */ if (old_port[VLIB_TX] && new_port[VLIB_TX]) sum = ip_csum_update (sum, old_port[VLIB_TX], new_port[VLIB_TX], @@ -569,6 +599,17 @@ cnat_translation_icmp6_error (ip6_header_t * outer_ip6, cnat_ip6_translate_l4 (ip6, udp, &inner_l4_sum, new_addr, new_port, 0 /* oflags */); tcp->checksum = ip_csum_fold (inner_l4_sum); + + /* TCP checksum changed */ + sum = ip_csum_update (sum, inner_l4_old_sum, inner_l4_sum, ip4_header_t, + checksum); + + /* TCP Ports changed */ + sum = ip_csum_update (sum, old_port[VLIB_TX], new_port[VLIB_TX], + tcp_header_t, dst_port); + + sum = ip_csum_update (sum, old_port[VLIB_RX], new_port[VLIB_RX], + tcp_header_t, src_port); } else if (ip6->protocol == IP_PROTOCOL_UDP) { @@ -576,21 +617,61 @@ cnat_translation_icmp6_error (ip6_header_t * outer_ip6, cnat_ip6_translate_l4 (ip6, udp, &inner_l4_sum, new_addr, new_port, 0 /* oflags */); udp->checksum = ip_csum_fold (inner_l4_sum); + + /* UDP checksum changed */ + sum = ip_csum_update (sum, inner_l4_old_sum, inner_l4_sum, ip4_header_t, + checksum); + + /* UDP Ports changed */ + sum = ip_csum_update (sum, old_port[VLIB_TX], new_port[VLIB_TX], + udp_header_t, dst_port); + + sum = ip_csum_update (sum, old_port[VLIB_RX], new_port[VLIB_RX], + udp_header_t, src_port); + } + else if (ip6->protocol == IP_PROTOCOL_ICMP6) + { + /* Update ICMP6 checksum */ + icmp46_header_t *inner_icmp = (icmp46_header_t *) udp; + inner_l4_old_sum = inner_l4_sum = inner_icmp->checksum; + if (icmp6_type_is_echo (inner_icmp->type)) + { + cnat_echo_header_t *echo = (cnat_echo_header_t *) (inner_icmp + 1); + u16 old_port = echo->identifier; + echo->identifier = new_port[VLIB_RX]; + inner_l4_sum = ip_csum_update ( + inner_l4_sum, old_port, new_port[VLIB_RX], udp_header_t, src_port); + + sum = ip_csum_update (sum, old_port, new_port[VLIB_RX], udp_header_t, + src_port); + } + + inner_l4_sum = + ip_csum_add_even (inner_l4_sum, new_addr[VLIB_TX].as_u64[0]); + inner_l4_sum = + ip_csum_add_even (inner_l4_sum, new_addr[VLIB_TX].as_u64[1]); + inner_l4_sum = + ip_csum_sub_even (inner_l4_sum, ip6->dst_address.as_u64[0]); + inner_l4_sum = + ip_csum_sub_even (inner_l4_sum, ip6->dst_address.as_u64[1]); + + inner_l4_sum = + ip_csum_add_even (inner_l4_sum, new_addr[VLIB_RX].as_u64[0]); + inner_l4_sum = + ip_csum_add_even (inner_l4_sum, new_addr[VLIB_RX].as_u64[1]); + inner_l4_sum = + ip_csum_sub_even (inner_l4_sum, ip6->src_address.as_u64[0]); + inner_l4_sum = + ip_csum_sub_even (inner_l4_sum, ip6->src_address.as_u64[1]); + inner_icmp->checksum = ip_csum_fold (inner_l4_sum); + + /* Update ICMP6 checksum change */ + sum = ip_csum_update (sum, inner_l4_old_sum, inner_l4_sum, ip4_header_t, + checksum); } else return; - /* UDP/TCP checksum changed */ - sum = ip_csum_update (sum, inner_l4_old_sum, inner_l4_sum, ip4_header_t, - checksum); - - /* UDP/TCP Ports changed */ - sum = ip_csum_update (sum, old_port[VLIB_TX], new_port[VLIB_TX], - udp_header_t, dst_port); - - sum = ip_csum_update (sum, old_port[VLIB_RX], new_port[VLIB_RX], - udp_header_t, src_port); - cnat_ip6_translate_l3 (ip6, new_addr); /* IP src/dst addr changed */ sum = ip_csum_add_even (sum, new_addr[VLIB_TX].as_u64[0]); @@ -677,15 +758,35 @@ cnat_session_make_key (vlib_buffer_t *b, ip_address_family_t af, if (icmp_type_is_error_message (icmp->type)) { ip4 = (ip4_header_t *) (icmp + 2); /* Use inner packet */ - udp = (udp_header_t *) (ip4 + 1); - /* Swap dst & src for search as ICMP payload is reversed */ - ip46_address_set_ip4 (&session->key.cs_ip[VLIB_RX], - &ip4->dst_address); - ip46_address_set_ip4 (&session->key.cs_ip[VLIB_TX], - &ip4->src_address); - session->key.cs_proto = ip4->protocol; - session->key.cs_port[VLIB_TX] = udp->src_port; - session->key.cs_port[VLIB_RX] = udp->dst_port; + if (PREDICT_FALSE (ip4->protocol == IP_PROTOCOL_ICMP)) + { + icmp = (icmp46_header_t *) (ip4 + 1); + if (icmp_type_is_echo (icmp->type)) + { + cnat_echo_header_t *echo = + (cnat_echo_header_t *) (icmp + 1); + ip46_address_set_ip4 (&session->key.cs_ip[VLIB_RX], + &ip4->dst_address); + ip46_address_set_ip4 (&session->key.cs_ip[VLIB_TX], + &ip4->src_address); + session->key.cs_proto = ip4->protocol; + session->key.cs_port[VLIB_TX] = echo->identifier; + session->key.cs_port[VLIB_RX] = echo->identifier; + } + } + else + { + udp = (udp_header_t *) (ip4 + 1); + /* Swap dst & src for search as ICMP payload is reversed + */ + ip46_address_set_ip4 (&session->key.cs_ip[VLIB_RX], + &ip4->dst_address); + ip46_address_set_ip4 (&session->key.cs_ip[VLIB_TX], + &ip4->src_address); + session->key.cs_proto = ip4->protocol; + session->key.cs_port[VLIB_TX] = udp->src_port; + session->key.cs_port[VLIB_RX] = udp->dst_port; + } } else if (icmp_type_is_echo (icmp->type)) { @@ -738,15 +839,35 @@ cnat_session_make_key (vlib_buffer_t *b, ip_address_family_t af, if (icmp6_type_is_error_message (icmp->type)) { ip6 = (ip6_header_t *) (icmp + 2); /* Use inner packet */ - udp = (udp_header_t *) (ip6 + 1); - /* Swap dst & src for search as ICMP payload is reversed */ - ip46_address_set_ip6 (&session->key.cs_ip[VLIB_RX], - &ip6->dst_address); - ip46_address_set_ip6 (&session->key.cs_ip[VLIB_TX], - &ip6->src_address); - session->key.cs_proto = ip6->protocol; - session->key.cs_port[VLIB_TX] = udp->src_port; - session->key.cs_port[VLIB_RX] = udp->dst_port; + if (PREDICT_FALSE (ip6->protocol == IP_PROTOCOL_ICMP6)) + { + icmp = (icmp46_header_t *) (ip6 + 1); + if (icmp6_type_is_echo (icmp->type)) + { + cnat_echo_header_t *echo = + (cnat_echo_header_t *) (icmp + 1); + ip46_address_set_ip6 (&session->key.cs_ip[VLIB_RX], + &ip6->dst_address); + ip46_address_set_ip6 (&session->key.cs_ip[VLIB_TX], + &ip6->src_address); + session->key.cs_proto = ip6->protocol; + session->key.cs_port[VLIB_TX] = echo->identifier; + session->key.cs_port[VLIB_RX] = echo->identifier; + } + } + else + { + udp = (udp_header_t *) (ip6 + 1); + /* Swap dst & src for search as ICMP payload is reversed + */ + ip46_address_set_ip6 (&session->key.cs_ip[VLIB_RX], + &ip6->dst_address); + ip46_address_set_ip6 (&session->key.cs_ip[VLIB_TX], + &ip6->src_address); + session->key.cs_proto = ip6->protocol; + session->key.cs_port[VLIB_TX] = udp->src_port; + session->key.cs_port[VLIB_RX] = udp->dst_port; + } } else if (icmp6_type_is_echo (icmp->type)) { diff --git a/test/test_cnat.py b/test/test_cnat.py index 9e979a4e09e..8d8f3210577 100644 --- a/test/test_cnat.py +++ b/test/test_cnat.py @@ -11,9 +11,10 @@ from config import config from scapy.packet import Raw from scapy.layers.l2 import Ether from scapy.layers.inet import IP, UDP, TCP, ICMP -from scapy.layers.inet import IPerror, TCPerror, UDPerror +from scapy.layers.inet import IPerror, TCPerror, UDPerror, ICMPerror from scapy.layers.inet6 import IPv6, IPerror6, ICMPv6DestUnreach from scapy.layers.inet6 import ICMPv6EchoRequest, ICMPv6EchoReply +from scapy.layers.inet6 import ICMPv6TimeExceeded from ipaddress import ip_network @@ -760,12 +761,14 @@ class TestCNatSourceNAT(CnatCommonTestCase): self.sourcenat_test_tcp_udp_conf(TCP, is_v6=True) self.sourcenat_test_tcp_udp_conf(UDP, is_v6=True) self.sourcenat_test_icmp_echo_conf(is_v6=True) + self.sourcenat_test_icmp_traceroute_conf(is_v6=True) def test_snat_v4(self): # """ CNat Source Nat v4 """ self.sourcenat_test_tcp_udp_conf(TCP) self.sourcenat_test_tcp_udp_conf(UDP) self.sourcenat_test_icmp_echo_conf() + self.sourcenat_test_icmp_traceroute_conf() def sourcenat_test_icmp_echo_conf(self, is_v6=False): ctx = CnatTestContext(self, ICMP, is_v6=is_v6) @@ -774,6 +777,125 @@ class TestCNatSourceNAT(CnatCommonTestCase): ctx.cnat_expect(self.pg2, 0, None, self.pg1, 0, 8) ctx.cnat_send_return().cnat_expect_return() + def sourcenat_test_icmp_traceroute_conf(self, is_v6=False): + # IPv4 ICMP + if not is_v6: + # Create an ICMP traceroute packet with TTL set to 1. + # The CNAT translates the packet, but the NATted packet is dropped + # due to the TTL of 1. An ICMP Time Exceeded message is sent + # to the source (which is the NATted address). + # The packet will be translated once more to the original + # source IP address. + + icmp = ( + Ether(src=self.pg0.remote_mac, dst=self.pg0.local_mac) + / IP( + ttl=1, + src=self.pg0.remote_hosts[0].ip4, + dst=self.pg1.remote_hosts[0].ip4, + ) + / ICMP(id=0xFEED, type=8) # ICMP Type Echo Request + / Raw() + ) + + self.rxs = self.send_and_expect(self.pg0, icmp, self.pg0) + + for rx in self.rxs: + self.assert_packet_checksums_valid(rx) + self.assertEqual(rx[IP].dst, self.pg0.remote_hosts[0].ip4) + self.assertEqual(rx[IP].src, "172.16.1.1") + self.assertEqual(rx[ICMP].type, 11) # ICMP Type 11 (Time Exceeded) + self.assertEqual( + rx[ICMP].code, 0 + ) # ICMP Code 0 (TTL Zero During Transit) + inner = rx[ICMP].payload + self.assertEqual(inner[IPerror].src, self.pg0.remote_hosts[0].ip4) + self.assertEqual(inner[IPerror].dst, self.pg1.remote_hosts[0].ip4) + self.assertEqual(inner[ICMPerror].type, 8) # ICMP Echo Request + self.assertEqual(inner[ICMPerror].id, 0xFEED) + + # source ---> NATted Transit ---> Transit 2 ... ---> Transit N ---> Destination + # Simulate an ICMP Time Exceeded message arriving at the NATted Transit + # from the Transit N-2 node. This occurs because the NATted packet + # is dropped due to a TTL of 1. + # An ICMP Time Exceeded message is sent back to the source + # (initially the NATted address). The CNAT then translates the message + # back to the original source IP address. + + # For ICMP based traffic, snat session uses identifier for session key. + # snat allocates a new identifier. To hit the snat session from Transit N-2 + # to NATed Transit, packet should use snat allocated identifier. To get the + # snat allocated identifier, echo request will be sent and captured at the + # destination, taken out the identifier from the packet and use it to set + # the identifier in the ICMP time exceed packet + icmp[IP].ttl = 64 + rxs = self.send_and_expect(self.pg0, icmp, self.pg1) + + icmp_error = ( + Ether(src=self.pg1.remote_mac, dst=self.pg1.local_mac) + / IP(src="172.16.1.1", dst=self.pg2.remote_hosts[0].ip4) + / ICMP(type=11, code=0) + / IPerror( + src=self.pg2.remote_hosts[0].ip4, dst=self.pg1.remote_hosts[0].ip4 + ) + / ICMPerror(id=rxs[0][ICMP].id, type=8) + / Raw() + ) + + self.rxs = self.send_and_expect(self.pg1, icmp_error, self.pg0) + for rx in self.rxs: + self.assert_packet_checksums_valid(rx) + self.assertEqual(rx[IP].dst, self.pg0.remote_hosts[0].ip4) + self.assertEqual(rx[IP].src, "172.16.1.1") + self.assertEqual(rx[ICMP].type, 11) # ICMP Type 11 (Time Exceeded) + self.assertEqual( + rx[ICMP].code, 0 + ) # ICMP Code 0 (TTL Zero During Transit) + inner = rx[ICMP].payload + self.assertEqual(inner[IPerror].src, self.pg0.remote_hosts[0].ip4) + self.assertEqual(inner[IPerror].dst, self.pg1.remote_hosts[0].ip4) + self.assertEqual(inner[ICMPerror].type, 8) # ICMP Echo Request + self.assertEqual(inner[ICMPerror].id, 0xFEED) + + # IPv6 ICMPv6 + if is_v6: + + # Create an ICMPv6 traceroute packet with Hop Limit set to 1. + # The CNAT translates the packet, but the NATted packet is dropped + # due to the Hop Limit of 1. An ICMPv6 Time Exceeded message is sent + # back to the source (which is the NATted address). + # The CNAT translates the message once more to restore + # the original source IPv6 address. + icmp6 = ( + Ether(src=self.pg0.remote_mac, dst=self.pg0.local_mac) + / IPv6( + hlim=1, + src=self.pg0.remote_hosts[0].ip6, + dst=self.pg1.remote_hosts[0].ip6, + ) + / ICMPv6EchoRequest(id=0xFEED) + / Raw() + ) + self.rxs = self.send_and_expect(self.pg0, icmp6, self.pg0) + + for rx in self.rxs: + self.assert_packet_checksums_valid(rx) + self.assertEqual(rx[IPv6].dst, self.pg0.remote_hosts[0].ip6) + self.assertEqual(rx[IPv6].src, "fd01:1::1") + self.assertEqual( + rx[ICMPv6TimeExceeded].type, 3 + ) # ICMPv6 Type 3 (Time Exceeded) + self.assertEqual( + rx[ICMPv6TimeExceeded].code, 0 + ) # ICMPv6 Code 0 (TTL Zero During Transit) + inner = rx[ICMPv6TimeExceeded].payload + self.assertEqual(inner[IPerror6].src, self.pg0.remote_hosts[0].ip6) + self.assertEqual(inner[IPerror6].dst, self.pg1.remote_hosts[0].ip6) + self.assertEqual( + inner[ICMPv6EchoRequest].type, 128 + ) # ICMPv6 Echo Request + self.assertEqual(inner[ICMPv6EchoRequest].id, 0xFEED) + def sourcenat_test_tcp_udp_conf(self, L4PROTO, is_v6=False): ctx = CnatTestContext(self, L4PROTO, is_v6) # we should source NAT @@ -823,7 +945,7 @@ class TestCNatSourceNAT(CnatCommonTestCase): @unittest.skipIf("cnat" in config.excluded_plugins, "Exclude CNAT plugin tests") class TestCNatDHCP(CnatCommonTestCase): - """CNat Translation""" + """CNat DHCP""" @classmethod def setUpClass(cls):