zapret/nfq/checksum.c
2021-03-04 14:30:38 +03:00

139 lines
3.3 KiB
C

#define _GNU_SOURCE
#include "checksum.h"
#include <netinet/in.h>
//#define htonll(x) ((1==htonl(1)) ? (x) : ((uint64_t)htonl((x) & 0xFFFFFFFF) << 32) | htonl((x) >> 32))
//#define ntohll(x) ((1==ntohl(1)) ? (x) : ((uint64_t)ntohl((x) & 0xFFFFFFFF) << 32) | ntohl((x) >> 32))
static uint16_t from64to16(uint64_t x)
{
uint32_t u = (uint32_t)(uint16_t)x + (uint16_t)(x>>16) + (uint16_t)(x>>32) + (uint16_t)(x>>48);
return (uint16_t)u + (uint16_t)(u>>16);
}
static uint16_t do_csum(const uint8_t * buff, size_t len)
{
uint8_t odd;
size_t count;
uint64_t result,w,carry=0;
uint16_t u16;
if (len <= 0) return 0;
odd = (uint8_t)(1 & (size_t)buff);
if (odd)
{
// any endian compatible
u16 = 0;
*((uint8_t*)&u16+1) = *buff;
result = u16;
len--;
buff++;
}
else
result = 0;
count = len >> 1; /* nr of 16-bit words.. */
if (count)
{
if (2 & (size_t) buff)
{
result += *(uint16_t *) buff;
count--;
len -= 2;
buff += 2;
}
count >>= 1; /* nr of 32-bit words.. */
if (count)
{
if (4 & (size_t) buff)
{
result += *(uint32_t *) buff;
count--;
len -= 4;
buff += 4;
}
count >>= 1; /* nr of 64-bit words.. */
if (count)
{
do
{
w = *(uint64_t *) buff;
count--;
buff += 8;
result += carry;
result += w;
carry = (w > result);
} while (count);
result += carry;
result = (result & 0xffffffff) + (result >> 32);
}
if (len & 4)
{
result += *(uint32_t *) buff;
buff += 4;
}
}
if (len & 2)
{
result += *(uint16_t *) buff;
buff += 2;
}
}
if (len & 1)
{
// any endian compatible
u16 = 0;
*(uint8_t*)&u16 = *buff;
result += u16;
}
u16 = from64to16(result);
if (odd) u16 = ((u16 >> 8) & 0xff) | ((u16 & 0xff) << 8);
return u16;
}
uint16_t csum_partial(const void *buff, size_t len)
{
return do_csum(buff,len);
}
uint16_t csum_tcpudp_magic(uint32_t saddr, uint32_t daddr, size_t len, uint8_t proto, uint16_t sum)
{
return ~from64to16((uint64_t)saddr + daddr + sum + htonl(len+proto));
}
uint16_t ip4_compute_csum(const void *buff, size_t len)
{
return ~from64to16(do_csum(buff,len));
}
void ip4_fix_checksum(struct ip *ip)
{
ip->ip_sum = 0;
ip->ip_sum = ip4_compute_csum(ip, ip->ip_hl<<2);
}
uint16_t csum_ipv6_magic(const void *saddr, const void *daddr, size_t len, uint8_t proto, uint16_t sum)
{
uint64_t a = (uint64_t)sum + htonl(len+proto) +
*(uint32_t*)saddr + *((uint32_t*)saddr+1) + *((uint32_t*)saddr+2) + *((uint32_t*)saddr+3) +
*(uint32_t*)daddr + *((uint32_t*)daddr+1) + *((uint32_t*)daddr+2) + *((uint32_t*)daddr+3);
return ~from64to16(a);
}
void tcp4_fix_checksum(struct tcphdr *tcp,size_t len, const struct in_addr *src_addr, const struct in_addr *dest_addr)
{
tcp->th_sum = 0;
tcp->th_sum = csum_tcpudp_magic(src_addr->s_addr,dest_addr->s_addr,len,IPPROTO_TCP,csum_partial(tcp, len));
}
void tcp6_fix_checksum(struct tcphdr *tcp,size_t len, const struct in6_addr *src_addr, const struct in6_addr *dest_addr)
{
tcp->th_sum = 0;
tcp->th_sum = csum_ipv6_magic(src_addr,dest_addr,len,IPPROTO_TCP,csum_partial(tcp, len));
}
void tcp_fix_checksum(struct tcphdr *tcp,size_t len,const struct ip *ip,const struct ip6_hdr *ip6hdr)
{
if (ip)
tcp4_fix_checksum(tcp, len, &ip->ip_src, &ip->ip_dst);
else if (ip6hdr)
tcp6_fix_checksum(tcp, len, &ip6hdr->ip6_src, &ip6hdr->ip6_dst);
}