/*
 * INET		An implementation of the TCP/IP protocol suite for the LINUX
 *		operating system.  INET is implemented using the  BSD Socket
 *		interface as the means of communication with the user level.
 *
 *		The User Datagram Protocol (UDP).
 *
 * Version:	@(#)udp.c	1.0.4	05/07/93
 *
 * Authors:	Ross Biro, <bir7@leland.Stanford.Edu>
 *		Fred N. van Kempen, <waltje@uWalt.NL.Mugnet.ORG>
 *
 *		This program is free software; you can redistribute it and/or
 *		modify it under the terms of the GNU General Public License
 *		as published by the Free Software Foundation; either version
 *		2 of the License, or (at your option) any later version.
 */
#include <asm/system.h>
#include <asm/segment.h>
#include <linux/types.h>
#include <linux/sched.h>
#include <linux/fcntl.h>
#include <linux/socket.h>
#include <linux/sockios.h>
#include <linux/in.h>
#include <linux/errno.h>
#include <linux/timer.h>
#include <linux/termios.h>
#include <linux/mm.h>
#include "inet.h"
#include "timer.h"
#include "dev.h"
#include "ip.h"
#include "protocol.h"
#include "tcp.h"
#include "skbuff.h"
#include "sock.h"
#include "udp.h"
#include "icmp.h"


/*
 * Debug tracing switch for this file.  With UDP_DEBUG undefined (the
 * state forced by the #undef below) PRINTK() expands to nothing and
 * the body of print_uh() is compiled out; change the #undef to a
 * #define to route PRINTK() to printk().
 *
 * PRINTK takes ONE macro argument, so callers wrap the whole printk
 * argument list in an extra pair of parentheses:
 *	PRINTK(("fmt %d\n", value));
 */
#undef	UDP_DEBUG
#ifdef	UDP_DEBUG
#   define PRINTK(x)	printk x
#else
#   define PRINTK(x)	/**/
#endif


/*
 * Smaller of two values.  The selected argument is evaluated twice,
 * so do not pass expressions with side effects.
 */
#define min(a,b)	((a)<(b)?(a):(b))


/*
 * Debug helper: dump the four fields of a UDP header, converted with
 * ntohs(), or "(NULL)" when no header is supplied.  Does nothing
 * unless UDP_DEBUG is defined.
 */
static void
print_uh(struct udphdr *uh)
{
#ifdef UDP_DEBUG
  if (uh != NULL) {
	printk("source = %d, dest = %d\n", ntohs(uh->source), ntohs(uh->dest));
	printk("len = %d, check = %d\n", ntohs(uh->len), ntohs(uh->check));
  } else {
	printk("(NULL)\n");
  }
#endif
}


/*
 * select() support for UDP sockets.
 *
 * Registers the caller on the socket's wait queue and then reports
 * whether the requested condition already holds:
 *	SEL_IN  - the receive queue is non-empty
 *	SEL_OUT - at least MIN_WRITE_SPACE bytes of write space are free
 *	SEL_EX  - an error is pending on the socket
 * Returns 1 when the condition holds, 0 otherwise (including for an
 * unrecognised sel_type).
 */
int
udp_select(struct sock *sk, int sel_type, select_table *wait)
{
  int ready = 0;

  select_wait(sk->sleep, wait);

  if (sel_type == SEL_IN) {
	ready = (sk->rqueue != NULL);
  } else if (sel_type == SEL_OUT) {
	ready = (sk->prot->wspace(sk) >= MIN_WRITE_SPACE);
  } else if (sel_type == SEL_EX) {
	/* can this ever happen? */
	ready = (sk->err != 0);
  }

  return(ready);
}


/*
 * This routine is called by the ICMP module when it gets some
 * sort of error condition.  If err < 0 then the socket should
 * be closed and the error returned to the user.  If err > 0
 * it's just the icmp type << 8 | icmp code.
 * header points to the first 8 bytes of the UDP header of the
 * datagram the error refers to.  We need to find the appropriate
 * port.
 *
 * A source quench halves the socket's congestion window (never
 * below 1) and is otherwise ignored.  Any other error is translated
 * through icmp_err_convert[] and stored in sk->err; if the table
 * marks it fatal and the socket is connected, the socket is closed.
 */
void
udp_err(int err, unsigned char *header, unsigned long daddr,
	unsigned long saddr, struct inet_protocol *protocol)
{
  struct udphdr *th;
  struct sock *sk;
   
  PRINTK(("udp_err(err=%d, header=%X, daddr=%X, saddr=%X, inet_protocl=%X)\n",
	   err, header, daddr, saddr, protocol));

  th = (struct udphdr *)header;
  sk = get_sock(&udp_prot, ntohs(th->dest), saddr, th->source, daddr);

  if (sk == NULL) return;

  /*
   * The parentheses around the mask matter: `==' binds tighter than
   * `&', so without them the test is `err & (0xff00 == ...)', which
   * can never match a source quench.
   */
  if ((err & 0xff00) == (ICMP_SOURCE_QUENCH << 8)) {
	if (sk->cong_window > 1) sk->cong_window = sk->cong_window/2;
	return;
  }

  sk->err = icmp_err_convert[err & 0xff].errno;

  /* It's only fatal if we have connected to them. */
  if (icmp_err_convert[err & 0xff].fatal && sk->state == TCP_ESTABLISHED) {
	sk->prot->close(sk, 0);
  }
}


/*
 * Compute the UDP checksum of the `len' bytes starting at `uh'
 * (header plus data), seeded with daddr, saddr and a word built from
 * the length and IPPROTO_UDP.
 *
 * All additions are done with end-around carry (the trailing
 * "adc $0" in each asm block).  The running sum lives in %ebx and
 * the data pointer in %esi; each asm block hands them to the next
 * one through `sum' and `uh', so the order of the blocks matters.
 *
 * Returns the complement of the low 16 bits of the sum.
 * udp_send_check() zeroes uh->check before calling this and stores
 * the result there; udp_rcv() treats a non-zero result as a bad
 * checksum.
 *
 * i386 only (GCC inline assembly).
 */
static unsigned short
udp_check(struct udphdr *uh, int len,
	  unsigned long saddr, unsigned long daddr)
{
  unsigned long sum;

  PRINTK(("udp_check(uh=%X, len = %d, saddr = %X, daddr = %X)\n",
	   uh, len, saddr, daddr));

  print_uh(uh);

  /*
   * Seed: sum = daddr + saddr + ((ntohs(len) << 16) + IPPROTO_UDP*256).
   * NOTE(review): the third term presumably is the protocol/length
   * word of the UDP pseudo-header as it appears in memory on a
   * little-endian machine -- confirm against RFC 768.
   */
  __asm__("\t addl %%ecx,%%ebx\n"
	  "\t adcl %%edx,%%ebx\n"
	  "\t adcl $0, %%ebx\n"
	  : "=b"(sum)
	  : "0"(daddr), "c"(saddr), "d"((ntohs(len) << 16) + IPPROTO_UDP*256)
	  : "cx","bx","dx" );

  /* Add len/4 whole 32-bit words; uh is advanced past them. */
  if (len > 3) {
	__asm__("\tclc\n"
		"1:\n"
		"\t lodsl\n"
		"\t adcl %%eax, %%ebx\n"
		"\t loop 1b\n"
		"\t adcl $0, %%ebx\n"
		: "=b"(sum) , "=S"(uh)
		: "0"(sum), "c"(len/4) ,"1"(uh)
		: "ax", "cx", "bx", "si" );
  }

  /* Convert from 32 bits to 16 bits. */
  __asm__("\t movl %%ebx, %%ecx\n"
	  "\t shrl $16,%%ecx\n"
	  "\t addw %%cx, %%bx\n"
	  "\t adcw $0, %%bx\n"
	  : "=b"(sum)
	  : "0"(sum)
	  : "bx", "cx");

  /* Check for an extra word. */
  if ((len & 2) != 0) {
	__asm__("\t lodsw\n"
		"\t addw %%ax,%%bx\n"
		"\t adcw $0, %%bx\n"
		: "=b"(sum), "=S"(uh)
		: "0"(sum) ,"1"(uh)
		: "si", "ax", "bx");
  }

  /*
   * Now check for the extra byte.  It is added as the low byte of a
   * 16-bit word whose high byte is forced to zero.
   */
  if ((len & 1) != 0) {
	__asm__("\t lodsb\n"
		"\t movb $0,%%ah\n"
		"\t addw %%ax,%%bx\n"
		"\t adcw $0, %%bx\n"
		: "=b"(sum)
		: "0"(sum) ,"S"(uh)
		: "si", "ax", "bx");
  }

  /* We only want the bottom 16 bits, but we never cleared the top 16. */
  return((~sum) & 0xffff);
}


static void
udp_send_check(struct udphdr *uh, unsigned long saddr, 
	       unsigned long daddr, int len, struct sock *sk)
{
  uh->check = 0;
  if (sk && sk->no_check) return;
  uh->check = udp_check(uh, len, saddr, daddr);
}


/*
 * Hand a copy of an outgoing datagram directly to the local socket
 * that would receive it, without going through a device.
 *
 *	sk	sending socket
 *	port	destination port, network byte order
 *	from	user-space payload (copied with memcpy_fromfs)
 *	len	payload length in bytes
 *	daddr	destination address of the datagram
 *	saddr	source address of the datagram
 *
 * Returns 0 if no local socket matches, otherwise len (also when the
 * receive buffer could not be allocated and the copy was dropped).
 *
 * The sending socket is marked in use on entry and is released again
 * on EVERY return path, so the caller always gets it back unlocked.
 */
static int
udp_loopback(struct sock *sk, unsigned short port, unsigned char *from,
	     int len, unsigned long daddr, unsigned long saddr)
{
  struct udphdr *uh;
  struct sk_buff *skb;
  struct sock *pair;

  sk->inuse = 1;
  PRINTK(("udp_loopback \n"));

  pair = get_sock(sk->prot, ntohs(port), saddr, sk->dummy_th.source, daddr);
  if (pair == NULL) {
	release_sock(sk);
	return(0);
  }

  skb = pair->prot->rmalloc(pair, sizeof(*skb) + sizeof(*uh) + len + 4,
				   				0, GFP_KERNEL);

  /* If we didn't get the memory, just drop the packet. */
  if (skb == NULL) {
	release_sock(sk);
	return(len);
  }
  skb->sk = pair;
  skb->lock = 0;
  skb->mem_addr = skb;
  skb->mem_len = sizeof(*skb) + len + sizeof(*uh) + 4;

  /* Addresses are stored switched, exactly as udp_rcv() does. */
  skb->daddr = saddr;
  skb->saddr = daddr;
  skb->len = len;
  skb->h.raw =(unsigned char *)(skb+1);

  /*
   * Build the header the way udp_sendto() does: length in network
   * byte order, and a zero checksum since none is computed here.
   */
  uh = skb->h.uh;
  uh->source = sk->dummy_th.source;
  uh->dest = port;
  uh->len = ntohs(len + sizeof(*uh));
  uh->check = 0;
  /* verify_area(VERIFY_WRITE, from , len); */
  memcpy_fromfs(uh+1, from, len);

  /* Append to the tail of the receiver's circular receive queue. */
  pair->inuse = 1;
  if (pair->rqueue == NULL) {
	pair->rqueue = skb;
	skb->next = skb;
	skb->prev = skb;
  } else {
	skb->next = pair->rqueue;
	skb->prev = pair->rqueue->prev;
	skb->prev->next = skb;
	skb->next->prev = skb;
  }
  wake_up(pair->sleep);
  release_sock(pair);
  release_sock(sk);
  return(len);
}


/*
 * Send `len' bytes of user data as one or more UDP datagrams.
 *
 *	sk	 sending socket
 *	from	 user-space data
 *	len	 number of bytes to send
 *	noblock	 non-zero: return -EAGAIN instead of sleeping for memory
 *	flags	 must be 0
 *	usin	 user-space destination address, or NULL to use the
 *		 peer set by udp_connect()
 *	addr_len size of *usin as supplied by the caller
 *
 * Returns the number of payload bytes queued for transmission, or a
 * negative errno: -EINVAL for bad flags/length/address or an
 * unconnected socket with no address, -EAGAIN, -ERESTARTSYS, or
 * whatever build_header() reports.
 *
 * Data that does not fit in one device MTU is split here into
 * several independent datagrams (see the comment in the loop).
 */
static int
udp_sendto(struct sock *sk, unsigned char *from, int len, int noblock,
	   unsigned flags, struct sockaddr_in *usin, int addr_len)
{
  struct sk_buff *skb;
  struct udphdr *uh;
  unsigned char *buff;
  unsigned long saddr;
  int copied=0;
  int amt;
  struct device *dev=NULL;
  struct sockaddr_in sin;

  /* Check the flags. */
  if (flags) return(-EINVAL);
  if (len < 0) return(-EINVAL);
  if (len == 0) return(0);

  PRINTK(("sendto len = %d\n", len));

  /* Get and verify the address. */
  if (usin) {
	/*
	 * Compare as signed: against the unsigned sizeof a negative
	 * addr_len would be converted to a huge value and accepted.
	 */
	if (addr_len < (int) sizeof(sin)) return(-EINVAL);
	/* verify_area(VERIFY_WRITE, usin, sizeof(sin));*/
	memcpy_fromfs(&sin, usin, sizeof(sin));
	if (sin.sin_family && sin.sin_family != AF_INET) return(-EINVAL);
	if (sin.sin_port == 0) return(-EINVAL);
  } else {
	if (sk->state != TCP_ESTABLISHED) return(-EINVAL);
	sin.sin_family = AF_INET;
	sin.sin_port = sk->dummy_th.dest;
	sin.sin_addr.s_addr = sk->daddr;
  }

  /* Check for a valid saddr. */
  saddr = sk->saddr;
  if (chk_addr(saddr) == IS_BROADCAST) saddr = my_addr();

  /*
   * If it's a broadcast, make sure we get it.  The address to test
   * is the DESTINATION; saddr has just been replaced above if it
   * was a broadcast address.
   */
  if (chk_addr(sin.sin_addr.s_addr) == IS_BROADCAST) {
	int err;

	err = udp_loopback(sk, sin.sin_port, from, len,
				sin.sin_addr.s_addr, saddr);
	if (err < 0) return(err);
  }
  sk->inuse = 1;

  while(len > 0) {
	int tmp;

	skb = sk->prot->wmalloc(sk, len + sizeof(*skb) +
				sk->prot->max_header, 0, GFP_KERNEL);

	/* This should never happen, but it is possible. */
	if (skb == NULL) {
		/*
		 * Remember how much write memory was in use, unlock,
		 * and sleep unless some was freed in the meantime.
		 */
		tmp = sk->wmem_alloc;
		release_sock(sk);
		if (copied) return(copied);
		if (noblock) return(-EAGAIN);
		cli();
		if (tmp <= sk->wmem_alloc) {
			interruptible_sleep_on(sk->sleep);
			if (current->signal & ~current->blocked) {
				sti();
				if (copied) return(copied);
				return(-ERESTARTSYS);
			}
		}
		sk->inuse = 1;
		sti();
		continue;
	}

	skb->lock = 0;
	skb->mem_addr = skb;
	skb->mem_len = len + sizeof(*skb) + sk->prot->max_header;
	skb->sk = sk;
	skb->free = 1;
	skb->arp = 0;

	/* Now build the IP and MAC header. */
	buff = (unsigned char *)(skb+1);
	tmp = sk->prot->build_header(skb, saddr,
				     sin.sin_addr.s_addr, &dev,
				     IPPROTO_UDP, sk->opt, skb->mem_len);
	if (tmp < 0 ) {
		sk->prot->wfree(sk, skb->mem_addr, skb->mem_len);
		release_sock(sk);
		return(tmp);
	}
	buff += tmp;

	/*
	 * We shouldn't do this, instead we should just
	 * let the IP protocol fragment the packet.
	 */
	amt = min(len + tmp + sizeof(*uh), dev->mtu);
	PRINTK(("amt = %d, dev = %X, dev->mtu = %d\n", amt, dev, dev->mtu));

	/* amt: whole frame -> UDP header + data -> data only. */
	skb->len = amt;
	amt -= tmp; 

	uh =(struct udphdr *)buff;
	uh->len = ntohs(amt);
	uh->source = sk->dummy_th.source;
	uh->dest = sin.sin_port;

	amt -= sizeof(*uh);
	buff += sizeof(*uh);
	if (amt < 0) {
		printk("udp.c: amt = %d < 0\n",amt);
		/* The buffer was never queued: give it back. */
		sk->prot->wfree(sk, skb->mem_addr, skb->mem_len);
		release_sock(sk);
		return(copied);
	}

	/* verify_area(VERIFY_WRITE, from, amt);*/
	memcpy_fromfs( buff, from, amt);
	len -= amt;
	copied += amt;
	from += amt;
	udp_send_check(uh, saddr, sin.sin_addr.s_addr,
				  amt+sizeof(*uh), sk);
				  
	sk->prot->queue_xmit(sk, dev, skb, 1);
  }
  release_sock(sk);
  return(copied);
}


/*
 * write() entry point: the same as udp_sendto() with no destination
 * address, so the datagram goes to the peer recorded by
 * udp_connect().  Returns whatever udp_sendto() returns.
 */
static int
udp_write(struct sock *sk, unsigned char *buff, int len, int noblock,
	  unsigned flags)
{
  int sent;

  sent = udp_sendto(sk, buff, len, noblock, flags, NULL, 0);
  return(sent);
}


/*
 * ioctl() support for UDP sockets.
 *
 *	TIOCOUTQ  stores half of the socket's free write space
 *	TIOCINQ   stores the length of the first queued datagram,
 *		  or 0 if the receive queue is empty
 *
 * `arg' is the user-space address of an unsigned long that receives
 * the value.  Returns 0 on success, -EINVAL for an unknown command
 * or a socket in TCP_LISTEN state, or the error reported by
 * verify_area() when `arg' is not writable.
 */
int
udp_ioctl(struct sock *sk, int cmd, unsigned long arg)
{
	int err;

	switch(cmd) {
		default:
			return(-EINVAL);

		case TIOCOUTQ:
			{
				unsigned long amount;

				if (sk->state == TCP_LISTEN) return(-EINVAL);
				amount = sk->prot->wspace(sk)/2;
				/* Never store through an unchecked pointer. */
				err = verify_area(VERIFY_WRITE,(void *)arg,
						sizeof(unsigned long));
				if (err) return(err);
				put_fs_long(amount,(unsigned long *)arg);
				return(0);
			}

		case TIOCINQ:
#if 0	/* FIXME: */
		case FIONREAD:
#endif
			{
				struct sk_buff *skb;
				unsigned long amount;

				if (sk->state == TCP_LISTEN) return(-EINVAL);
				amount = 0;
				skb = sk->rqueue;
				if (skb != NULL) {
					/*
					 * We will only return the amount
					 * of this packet since that is all
					 * that will be read.
					 */
					amount = skb->len;
				}
				err = verify_area(VERIFY_WRITE,(void *)arg,
						sizeof(unsigned long));
				if (err) return(err);
				put_fs_long(amount,(unsigned long *)arg);
				return(0);
			}
	}
}


/*
 * This should be easy, if there is something there we
 * return it, otherwise we block.
 *
 *	sk	 receiving socket
 *	to	 user-space buffer for the payload
 *	len	 size of that buffer
 *	noblock	 non-zero: return -EAGAIN instead of sleeping
 *	flags	 MSG_PEEK leaves the datagram on the queue
 *	sin	 if non-NULL, user-space buffer for the sender's address
 *	addr_len if non-NULL, user-space int set to sizeof(*sin)
 *
 * Returns the number of bytes copied (the datagram is truncated to
 * `len'), 0 if the receive side has been shut down, or a negative
 * errno: a pending socket error, -EINVAL, -EAGAIN, -ERESTARTSYS, or
 * the verify_area() error for an unwritable user buffer.
 *
 * Every return path leaves the socket unlocked and interrupts
 * enabled.
 */
int
udp_recvfrom(struct sock *sk, unsigned char *to, int len,
	     int noblock, unsigned flags, struct sockaddr_in *sin,
	     int *addr_len)
{
  int copied=0;
  int err;
  struct sk_buff *skb;

  if (len == 0) return(0);
  if (len < 0) return(-EINVAL);

  /*
   * This will pick up errors that occurred while the program
   * was doing something else.
   */
  if (sk->err) {
	err = -sk->err;
	sk->err = 0;
	return(err);
  }

  /* Check the fixed-size user buffers before taking the socket. */
  if (addr_len) {
	err = verify_area(VERIFY_WRITE, addr_len, sizeof(*addr_len));
	if (err) return(err);
	put_fs_long(sizeof(*sin), addr_len);
  }
  if (sin) {
	err = verify_area(VERIFY_WRITE, sin, sizeof(*sin));
	if (err) return(err);
  }

  sk->inuse = 1;
  while(sk->rqueue == NULL) {
	if (sk->shutdown & RCV_SHUTDOWN) {
		/* Do not leave the socket marked in use. */
		release_sock(sk);
		return(0);
	}

	if (noblock) {
		release_sock(sk);
		return(-EAGAIN);
	}
	release_sock(sk);

	/*
	 * Re-test the queue with interrupts off so a packet that
	 * arrives between the test and the sleep is not missed.
	 */
	cli();
	if (sk->rqueue == NULL) {
		interruptible_sleep_on(sk->sleep);
		if (current->signal & ~current->blocked) {
			/* Socket is already released; undo the cli(). */
			sti();
			return(-ERESTARTSYS);
		}
	}
	sk->inuse = 1;
	sti();
  }
  skb = sk->rqueue;

  /*
   * Validate the destination before unlinking anything, so that a
   * bad user pointer does not cost the caller the datagram.
   */
  copied = min(len, skb->len);
  err = verify_area(VERIFY_WRITE, to, copied);
  if (err) {
	release_sock(sk);
	return(err);
  }

  /* Unlink the head of the circular receive queue. */
  if (!(flags & MSG_PEEK)) {
	if (skb->next == skb) {
		sk->rqueue = NULL;
	} else {
		sk->rqueue =(struct sk_buff *)sk->rqueue ->next;
		skb->prev->next = skb->next;
		skb->next->prev = skb->prev;
	}
  }
  memcpy_tofs(to, skb->h.raw + sizeof(struct udphdr), copied);

  /* Copy the address. */
  if (sin) {
	struct sockaddr_in addr;

	/*
	 * skb->daddr holds the SENDER's address: udp_rcv() stores
	 * the two addresses switched.
	 * NOTE(review): only these three fields are set, yet
	 * sizeof(*sin) bytes are copied out, so any remaining bytes
	 * of addr reach user space uninitialised.
	 */
	addr.sin_family = AF_INET;
	addr.sin_port = skb->h.uh->source;
	addr.sin_addr.s_addr = skb->daddr;
	memcpy_tofs(sin, &addr, sizeof(*sin));
  }

  if (!(flags & MSG_PEEK)) {
	kfree_skb(skb, FREE_READ);
  }
  release_sock(sk);
  return(copied);
}


/*
 * read() entry point: receive one datagram without reporting the
 * sender's address.  Returns whatever udp_recvfrom() returns.
 */
int
udp_read(struct sock *sk, unsigned char *buff, int len, int noblock,
	 unsigned flags)
{
  int got;

  got = udp_recvfrom(sk, buff, len, noblock, flags, NULL, NULL);
  return(got);
}


/*
 * "Connect" a UDP socket: record the peer's address and port so that
 * udp_write()/udp_sendto() may be used without a destination, and
 * mark the socket TCP_ESTABLISHED.  Nothing is sent.
 *
 *	usin	 user-space peer address
 *	addr_len size of *usin as supplied by the caller
 *
 * Returns 0 on success, -EINVAL if addr_len is too small (or
 * negative), -EAFNOSUPPORT if the family is neither 0 nor AF_INET.
 */
int
udp_connect(struct sock *sk, struct sockaddr_in *usin, int addr_len)
{
  struct sockaddr_in sin;

  /*
   * Compare as signed: against the unsigned sizeof a negative
   * addr_len would be converted to a huge value and accepted.
   */
  if (addr_len < (int) sizeof(sin)) return(-EINVAL);
  /* verify_area(VERIFY_WRITE, usin, sizeof(sin)); */
  memcpy_fromfs(&sin, usin, sizeof(sin));
  if (sin.sin_family && sin.sin_family != AF_INET) return(-EAFNOSUPPORT);
  sk->daddr = sin.sin_addr.s_addr;
  sk->dummy_th.dest = sin.sin_port;
  sk->state = TCP_ESTABLISHED;
  return(0);
}


/*
 * Close a UDP socket: lock it, move it to TCP_CLOSE, then either
 * destroy it (if it is already marked dead) or simply unlock it.
 * The timeout argument is not used.
 */
static void
udp_close(struct sock *sk, int timeout)
{
  sk->inuse = 1;
  sk->state = TCP_CLOSE;

  if (!sk->dead) {
	release_sock(sk);
	return;
  }
  destroy_sock(sk);
}


/* All we need to do is get the socket, and then do a checksum. */
int
udp_rcv(struct sk_buff *skb, struct device *dev, struct options *opt,
	unsigned long daddr, unsigned short len,
	unsigned long saddr, int redo, struct inet_protocol *protocol)
{
  struct proto *prot=&udp_prot;
  struct sock *sk;
  struct udphdr *uh;

  uh = (struct udphdr *) skb->h.uh;

  /*
   * FIXME:	THIS.IS.A.MESS. !
   * If the packet was not broadcasted out from a machine
   * that is using BOOTP to find out its IP address(which
   * is indicated by the fact that 'saddr' is 0.0.0.0), add
   * the source address to the ARP cache. - FvK
   */
  if ((saddr != 0) && dev->add_arp) dev->add_arp(saddr, skb, dev);

  sk = get_sock(prot, ntohs(uh->dest), saddr, uh->source, daddr);
  if (sk == NULL) {
	if ((daddr & 0xff000000 != 0) && (daddr & 0xff000000 != 0xff000000)) {
		icmp_reply(skb, ICMP_DEST_UNREACH, ICMP_PORT_UNREACH, dev);
	}
	skb->sk = NULL;
	kfree_skb(skb, 0);
	return(0);
  }

  if (!redo) {
	if (uh->check && udp_check(uh, len, saddr, daddr)) {
		PRINTK(("bad udp checksum\n"));
		skb->sk = NULL;
		kfree_skb(skb, 0);
		return(0);
	}

	skb->sk = sk;
	skb->dev = dev;
	skb->len = len;

	/* These are supposed to be switched. */
	skb->daddr = saddr;
	skb->saddr = daddr;

	/* Now deal with the in use. */
	cli();
	if (sk->inuse) {
		if (sk->back_log == NULL) {
			sk->back_log = skb;
			skb->next = skb;
			skb->prev = skb;
		} else {
			skb->next = sk->back_log;
			skb->prev = sk->back_log->prev;
			skb->prev->next = skb;
			skb->next->prev = skb;
		}
		sti();
		return(0);
	}
	sk->inuse = 1;
	sti();
  }

  /* Charge it to the socket. */
  if (sk->rmem_alloc + skb->mem_len >= SK_RMEM_MAX) {
	skb->sk = NULL;
	kfree_skb(skb, 0);
	release_sock(sk);
	return(0);
  }
  sk->rmem_alloc += skb->mem_len;

  /* At this point we should print the thing out. */
  PRINTK(("<< \n"));

  /* Now add it to the data chain and wake things up. */
  if (sk->rqueue == NULL) {
	sk->rqueue = skb;
	skb->next = skb;
	skb->prev = skb;
  } else {
	skb->next = sk->rqueue;
	skb->prev = sk->rqueue->prev;
	skb->prev->next = skb;
	skb->next->prev = skb;
  }
  skb->len = len - sizeof(*uh);

  if (!sk->dead) wake_up(sk->sleep);

  release_sock(sk);
  return(0);
}


/*
 * Protocol operations table for UDP.
 *
 * This is a POSITIONAL initializer: each entry must stay in the
 * order of the members of struct proto, which is declared outside
 * this file.  Code in this file reaches the table through sk->prot
 * and uses the wmalloc, rmalloc, wfree, wspace, close, build_header,
 * queue_xmit and max_header members.
 *
 * NOTE(review): the NULL entries are slots UDP leaves empty; which
 * operations they correspond to cannot be told from here -- check
 * the struct proto declaration before adding or reordering entries.
 */
struct proto udp_prot = {
  sock_wmalloc,
  sock_rmalloc,
  sock_wfree,
  sock_rfree,
  sock_rspace,
  sock_wspace,
  udp_close,
  udp_read,
  udp_write,
  udp_sendto,
  udp_recvfrom,
  ip_build_header,
  udp_connect,
  NULL,
  ip_queue_xmit,
  ip_retransmit,
  NULL,
  NULL,
  udp_rcv,
  udp_select,
  udp_ioctl,
  NULL,
  NULL,
  128,		/* presumably max_header, as read in udp_sendto() -- TODO confirm */
  0,
  {NULL,},
  "UDP"		/* protocol name */
};
