/*-
 * Copyright (c) 2004 Robert N. M. Watson
 * All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions
 * are met:
 * 1. Redistributions of source code must retain the above copyright
 *    notice, this list of conditions and the following disclaimer.
 * 2. Redistributions in binary form must reproduce the above copyright
 *    notice, this list of conditions and the following disclaimer in the
 *    documentation and/or other materials provided with the distribution.
 *
 * THIS SOFTWARE IS PROVIDED BY THE AUTHOR AND CONTRIBUTORS ``AS IS'' AND
 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
 * ARE DISCLAIMED.  IN NO EVENT SHALL THE AUTHOR OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS
 * OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
 * HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
 * LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY
 * OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF
 * SUCH DAMAGE.
 *
 * [id for your version control system, if any]
 */

#include <sys/types.h>

#include <net/ethernet.h>

#include <netinet/in.h>
#include <netinet/in_systm.h>
#include <netinet/ip.h>
#include <netinet/tcp.h>

#include <arpa/inet.h>

#include <assert.h>
#include <pthread.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>

#include "cksum.h"
#include "ip.h"

/*
 * One queued inbound TCP packet.  tcp_enqueue() stores a private malloc'd
 * copy of the packet bytes here; the dequeue routines free the entry
 * itself and hand tqe_packet back to their caller (which presumably must
 * free() it -- nothing in this file does).
 */
struct tcp_queue_entry {
	u_char			*tqe_packet;	/* Private copy of packet bytes. */
	u_int			 tqe_packetlen;	/* Length of tqe_packet in bytes. */
	struct tcp_queue_entry	*tqe_next;	/* Next entry, NULL at the tail. */
};

/*
 * Singly-linked FIFO of received packets, appended at tcp_queue_tail by
 * tcp_enqueue() and removed from tcp_queue_head by tcp_receive() and
 * tcp_receive_poll().  Both pointers are only touched with tcp_mutex
 * held; tcp_cond is signalled after each append.  Set up by tcp_init().
 */
static struct tcp_queue_entry	*tcp_queue_head, *tcp_queue_tail;
static pthread_mutex_t		 tcp_mutex;
static pthread_cond_t		 tcp_cond;

/*
 * Initialize the inbound TCP packet queue together with the mutex and
 * condition variable that protect it.  Should be called once, before any
 * other routine in this file touches the queue.
 *
 * The pthread initializers are called outside assert() so that they
 * still run when the file is built with -DNDEBUG; a failure is fatal.
 */
void
tcp_init(void)
{

	if (pthread_mutex_init(&tcp_mutex, NULL) != 0)
		abort();
	if (pthread_cond_init(&tcp_cond, NULL) != 0)
		abort();
	tcp_queue_head = tcp_queue_tail = NULL;
}

void
tcp_printhdr(struct ip *ip, struct tcphdr *th)
{
	u_char thflags;

	printf("IP source: %s", inet_ntoa(ip->ip_src));
	printf("  IP dest: %s\n", inet_ntoa(ip->ip_dst));
	printf("TCP source port: %d dest port: %d ", ntohs(th->th_sport),
	    ntohs(th->th_dport));
	thflags = th->th_flags;
	printf("flags: ");
	if (thflags & TH_SYN)
		printf("SYN ");
	if (thflags & TH_ACK)
		printf("ACK ");
	if (thflags & TH_FIN)
		printf("FIN ");
	if (thflags & TH_RST)
		printf("RST ");
	if (thflags & TH_PUSH)
		printf("PUSH ");
	if (thflags & TH_URG)
		printf("URG ");
	if (thflags & TH_ECE)
		printf("ECE ");
	if (thflags & TH_CWR)
		printf("CWR ");
	printf("\n");
}

/*
 * Append a private copy of the packetlen bytes at packet to the tail of
 * the inbound TCP queue, then signal tcp_cond so a thread waiting in
 * tcp_receive() can pick it up.  The caller keeps ownership of packet.
 *
 * Best-effort: if either allocation fails the packet is silently dropped.
 * The pthread calls are made outside assert() so the queue stays locked
 * and waiters are still woken in -DNDEBUG builds; failure is fatal.
 */
static void
tcp_enqueue(u_char *packet, u_int packetlen)
{
	struct tcp_queue_entry *tqe;
	u_char *local_packet;

	tqe = malloc(sizeof(*tqe));
	if (tqe == NULL)
		return;
	local_packet = malloc(packetlen);
	if (local_packet == NULL) {
		free(tqe);
		return;
	}
	memcpy(local_packet, packet, packetlen);
	tqe->tqe_packet = local_packet;
	tqe->tqe_packetlen = packetlen;
	tqe->tqe_next = NULL;

	if (pthread_mutex_lock(&tcp_mutex) != 0)
		abort();
	if (tcp_queue_tail != NULL)
		tcp_queue_tail->tqe_next = tqe;
	else
		tcp_queue_head = tqe;
	tcp_queue_tail = tqe;
	if (pthread_cond_signal(&tcp_cond) != 0)
		abort();
	if (pthread_mutex_unlock(&tcp_mutex) != 0)
		abort();
}

/*
 * ACK a TCP packet given the receieved IP and TCP headers.
 */
void
tcp_sendackseq(struct ip *ip, struct tcphdr *th, u_int32_t seq)
{
	struct ether_header *neweh;
	u_int packetlen, len;
	struct tcphdr *newth;
	struct ip *newip;
	u_char *packet;

	len = sizeof(*newip) + sizeof(*newth);
	packetlen = sizeof(*neweh) + len;
	packet = malloc(packetlen);
	if (packet == NULL)
		return;
	neweh = (struct ether_header *)(packet);
	newip = (struct ip *)(neweh + 1);
	newth = (struct tpchdr *)(newip + 1);

	newip->ip_hl = sizeof(*newip) >> 2;
	newip->ip_v = 4;
	newip->ip_tos = 0;
	newip->ip_len = htons(sizeof(*ip) + sizeof(*th));
	newip->ip_id = random();
	newip->ip_off = htons(0);
	newip->ip_ttl = 64;
	newip->ip_p = IPPROTO_TCP;
	newip->ip_sum = htons(0);	/* Will be filled in by ip_output(). */
	newip->ip_src = my_ipaddr;
	newip->ip_dst = ip->ip_src;

	newth->th_off = sizeof(*newth) >> 2;
	newth->th_x2 = 0;
	newth->th_sport = th->th_dport;
	newth->th_dport = th->th_sport;
	if ((th->th_flags & TH_SYN) && (th->th_flags & TH_ACK) == 0) {
		newth->th_ack = ntohl(htonl(th->th_seq) + 1);
		newth->th_seq = seq;
		newth->th_flags = TH_SYN | TH_ACK;
	} else if ((th->th_flags & TH_SYN) && (th->th_flags && TH_ACK)) {
		newth->th_ack = ntohl(htonl(th->th_seq) + 1);
		newth->th_seq = seq;
		newth->th_flags = TH_ACK;
	} else {
		newth->th_ack = ntohl(htonl(th->th_seq));
		newth->th_seq = seq;
		newth->th_flags = TH_ACK;
	}
	newth->th_win = htons(1);
	newth->th_urp = 0;

	newth->th_sum = in_pseudo(newip->ip_src.s_addr, newip->ip_dst.s_addr,
	    htons(sizeof(*newth) + IPPROTO_TCP));
	newth->th_sum = in_cksum((u_short *)newth, sizeof(*newth));

	ip_output(packet, packetlen);

	free(packet);
}

/*
 * Reply to ip/th exactly as tcp_sendackseq() does, using a sequence
 * number of zero.
 */
void
tcp_sendack(struct ip *ip, struct tcphdr *th)
{
	const u_int32_t zero_seq = htonl(0);

	tcp_sendackseq(ip, th, zero_seq);
}

void
tcp_sendrstseq(struct ip *ip, struct tcphdr *th, u_int32_t seq)
{
	struct ether_header *neweh;
	u_int packetlen, len;
	struct tcphdr *newth;
	struct ip *newip;
	u_char *packet;

	len = sizeof(*newip) + sizeof(*newth);
	packetlen = sizeof(*neweh) + len;
	packet = malloc(packetlen);
	if (packet == NULL)
		return;
	neweh = (struct ether_header *)(packet);
	newip = (struct ip *)(neweh + 1);
	newth = (struct tpchdr *)(newip + 1);

	newip->ip_hl = sizeof(*newip) >> 2;
	newip->ip_v = 4;
	newip->ip_tos = 0;
	newip->ip_len = htons(sizeof(*ip) + sizeof(*th));
	newip->ip_id = random();
	newip->ip_off = htons(0);
	newip->ip_ttl = 64;
	newip->ip_p = IPPROTO_TCP;
	newip->ip_sum = htons(0);	/* Will be filled in by ip_output(). */
	newip->ip_src = my_ipaddr;
	newip->ip_dst = ip->ip_src;

	newth->th_off = sizeof(*newth) >> 2;
	newth->th_x2 = 0;
	newth->th_sport = th->th_dport;
	newth->th_dport = th->th_sport;
	newth->th_flags = TH_RST | TH_ACK;
	newth->th_ack = htonl(ntohl(th->th_seq) + 1);
	newth->th_seq = seq;
	newth->th_win = htons(0);
	newth->th_urp = 0;

	newth->th_sum = in_pseudo(newip->ip_src.s_addr, newip->ip_dst.s_addr,
	    htons(sizeof(*newth) + IPPROTO_TCP));
	newth->th_sum = in_cksum((u_short *)newth, sizeof(*newth));

	ip_output(packet, packetlen);

	free(packet);
}

/*
 * Send an RST|ACK reply to ip/th via tcp_sendrstseq(), using a sequence
 * number of zero.
 */
void
tcp_sendrst(struct ip *ip, struct tcphdr *th)
{
	u_int32_t rst_seq;

	rst_seq = htonl(0);
	tcp_sendrstseq(ip, th, rst_seq);
}

/*
 * Compiled out: handlers that answer an inbound SYN or SYN|ACK directly
 * with tcp_sendack().  Their only caller is the equally disabled
 * dispatch code at the end of tcp_input(), which now queues every valid
 * segment with tcp_enqueue() instead.
 */
#if 0
static void
tcp_syn(struct ether_header *eh, struct ip *ip, struct tcphdr *th,
    u_int tlen, u_int payloadlen)
{

	tcp_sendack(ip, th);
}

static void
tcp_synack(struct ether_header *eh, struct ip *ip, struct tcphdr *th,
    u_int tlen, u_int payloadlen)
{

	tcp_sendack(ip, th);
}
#endif

/*
 * Inbound TCP entry point: sanity-check the TCP header that starts at
 * payload and, if it is well formed, queue a copy of the packet for
 * tcp_receive()/tcp_receive_poll().
 *
 * The queued copy starts at ip and is hlen + payloadlen bytes long.
 * NOTE(review): hlen is presumably the IP header length, with payload
 * following it; confirm against the caller.  Malformed segments are
 * reported on stdout and dropped.  eh is not used.
 */
void
tcp_input(struct ether_header *eh, struct ip *ip, u_int hlen,
    u_char *payload, u_int payloadlen)
{
	struct tcphdr *th;
	u_int hdrbytes;

	/* The fixed header must be present before th_off can be read. */
	if (payloadlen < sizeof(struct tcphdr)) {
		printf("tcp_input: truncated TCP header\n");
		return;
	}

	th = (struct tcphdr *)payload;
	hdrbytes = th->th_off * 4;	/* th_off counts 32-bit words. */

	/*
	 * Given the check above, these two conditions cannot both hold,
	 * so testing them in either order reports the same message.
	 */
	if (hdrbytes < sizeof(struct tcphdr)) {
		printf("tcp_input: TCP offset truncated TCP header\n");
		return;
	}
	if (hdrbytes > payloadlen) {
		printf("tcp_input: truncated extended TCP header\n");
		return;
	}

	tcp_enqueue((u_char *)ip, hlen + payloadlen);
#if 0
	/*
	 * Disabled SYN / SYN|ACK dispatch.  Re-enabling it requires
	 * declaring thflags (e.g. from th->th_flags) and the tcp_syn() /
	 * tcp_synack() handlers, which are also compiled out.
	 */
	if ((thflags & (TH_SYN | TH_ACK)) == (TH_SYN | TH_ACK)) {
		tcp_synack(eh, ip, th, tlen, payloadlen);
		return;
	} else if (thflags & TH_SYN) {
		tcp_syn(eh, ip, th, tlen, payloadlen);
		return;
	} else
		printf("tcp_input: unhandled type\n");
#endif
}

/*
 * Remove the oldest packet from the inbound TCP queue, blocking on
 * tcp_cond until one is available.  On return *packet points to the
 * malloc'd copy made by tcp_enqueue() and *packetlen holds its length.
 * The queue no longer references that buffer, so the caller presumably
 * has to free() it.
 *
 * The pthread calls are kept outside assert(): otherwise a -DNDEBUG
 * build drops both the lock and the wait, and the loop below becomes an
 * unlocked busy-spin.  Any pthread failure is fatal.
 */
void
tcp_receive(u_char **packet, u_int *packetlen)
{
	struct tcp_queue_entry *tqe;

	if (pthread_mutex_lock(&tcp_mutex) != 0)
		abort();
	/* Re-test after every wakeup: condition waits may wake spuriously. */
	while (tcp_queue_head == NULL) {
		if (pthread_cond_wait(&tcp_cond, &tcp_mutex) != 0)
			abort();
	}
	tqe = tcp_queue_head;
	tcp_queue_head = tqe->tqe_next;
	if (tcp_queue_tail == tqe)
		tcp_queue_tail = NULL;
	if (pthread_mutex_unlock(&tcp_mutex) != 0)
		abort();

	/* tqe is unlinked, so it can be consumed without the lock. */
	*packet = tqe->tqe_packet;
	*packetlen = tqe->tqe_packetlen;
	free(tqe);
}

/*
 * Non-blocking variant of tcp_receive().  If the inbound TCP queue is
 * empty, return 0 and leave *packet and *packetlen untouched.  Otherwise
 * remove the oldest packet, store its buffer and length through packet
 * and packetlen, and return 1.  Buffer ownership passes to the caller,
 * as with tcp_receive().
 *
 * The pthread calls are kept outside assert() so that they still execute
 * in -DNDEBUG builds; a failure is fatal.
 */
int
tcp_receive_poll(u_char **packet, u_int *packetlen)
{
	struct tcp_queue_entry *tqe;

	if (pthread_mutex_lock(&tcp_mutex) != 0)
		abort();
	tqe = tcp_queue_head;
	if (tqe != NULL) {
		tcp_queue_head = tqe->tqe_next;
		if (tcp_queue_tail == tqe)
			tcp_queue_tail = NULL;
	}
	if (pthread_mutex_unlock(&tcp_mutex) != 0)
		abort();

	if (tqe == NULL)
		return (0);

	/* tqe is unlinked, so it can be consumed without the lock. */
	*packet = tqe->tqe_packet;
	*packetlen = tqe->tqe_packetlen;
	free(tqe);
	return (1);
}
