/* Connection tracking via netlink socket. Allows for user space
 * protocol helpers and general trouble making from userspace.
 *
 * (C) 2001 by Jay Schulist <jschlst@samba.org>
 * (C) 2002 by Harald Welte <laforge@gnumonks.org>
 * (C) 2003 by Patrick Mchardy <kaber@trash.net>,
 *             Harald Welte <laforge@gnumonks.org>
 *
 * Initial connection tracking via netlink development funded and 
 * generally made possible by Network Robots, Inc. (www.networkrobots.com)
 *
 * Further development of this code funded by Astaro AG (http://www.astaro.com)
 *
 * This software may be used and distributed according to the terms
 * of the GNU General Public License, incorporated herein by reference.
 */

#include <linux/config.h>
#include <linux/module.h>
#include <linux/types.h>
#include <linux/socket.h>
#include <linux/kernel.h>
#include <linux/major.h>
#include <linux/sched.h>
#include <linux/timer.h>
#include <linux/string.h>
#include <linux/sockios.h>
#include <linux/net.h>
#include <linux/fcntl.h>
#include <linux/skbuff.h>
#include <asm/uaccess.h>
#include <asm/system.h>
#include <net/sock.h>
#include <linux/init.h>
#include <linux/netlink.h>
#include <linux/spinlock.h>
#include <linux/notifier.h>
#include <linux/rtnetlink.h>

#include <linux/netfilter.h>
#include <linux/netfilter_ipv4.h>
#include <linux/netfilter_ipv4/ip_tables.h>
#include <linux/netfilter_ipv4/ip_conntrack.h>
#include <linux/netfilter_ipv4/ip_conntrack_core.h>
#include <linux/netfilter_ipv4/ip_conntrack_helper.h>
#include <linux/netfilter_ipv4/ip_conntrack_protocol.h>

#include <linux/nfnetlink.h>
#include <linux/nfnetlink_conntrack.h>

/* Lock assertions defined ahead of listhelp.h (presumably consumed by its
 * list helpers -- confirm against that header).  Both ignore their argument
 * and always check ip_conntrack_lock. */
#define ASSERT_READ_LOCK(x) MUST_BE_READ_LOCKED(&ip_conntrack_lock)
#define ASSERT_WRITE_LOCK(x) MUST_BE_WRITE_LOCKED(&ip_conntrack_lock)
#include <linux/netfilter_ipv4/listhelp.h>

MODULE_LICENSE("GPL");

/* Version string printed once by ctnetlink_init(); __initdata because it is
 * only referenced from the __init function. */
static char __initdata ctversion[] = "0.12";

/* Debug output.  With the "#if 1" branch active, ct_debug(level, ...) prints
 * via printk(KERN_DEBUG) prefixed with the calling function's name whenever
 * ct_debug_level is greater than 'level'.  The #else branch compiles both
 * macros away to nothing. */
#if 1
static int ct_debug_level = 1;
#define ct_debug(level, format, arg...)					\
do {									\
	if(ct_debug_level > level)					\
		printk(KERN_DEBUG "%s: " format, __FUNCTION__, ## arg);	\
} while(0)
/* FIXME: this define is just needed for DUMP_TUPLE */
#define DEBUGP(format, args...)	ct_debug(0, format, ## args)
#else
#define ct_debug(level, format, arg...)
#define DEBUGP(format, args...)
#endif

/* nfnetlink subsystem descriptor; allocated in ctnetlink_init() and freed in
 * ctnetlink_exit(). */
static struct nfnetlink_subsystem *ctnl_subsys;


/* Append both direction tuples of @ct to @skb as the CTA_ORIG and CTA_RPLY
 * attributes.
 *
 * Returns 0 on success, or -1 when NFA_PUT jumps to nfattr_failure. */
static inline int
ctnetlink_dump_tuples(struct sk_buff *skb, const struct ip_conntrack *ct)
{
	const struct ip_conntrack_tuple *orig =
		&ct->tuplehash[IP_CT_DIR_ORIGINAL].tuple;
	const struct ip_conntrack_tuple *reply =
		&ct->tuplehash[IP_CT_DIR_REPLY].tuple;
	int ret = -1;

	NFA_PUT(skb, CTA_ORIG, sizeof(struct ip_conntrack_tuple), orig);
	NFA_PUT(skb, CTA_RPLY, sizeof(struct ip_conntrack_tuple), reply);
	ret = 0;

nfattr_failure:
	return ret;
}

/* Append the status word of @ct to @skb as a CTA_STATUS attribute.
 *
 * Returns 0 on success, or -1 when NFA_PUT jumps to nfattr_failure. */
static inline int
ctnetlink_dump_status(struct sk_buff *skb, const struct ip_conntrack *ct)
{
	int ret = -1;

	NFA_PUT(skb, CTA_STATUS, sizeof(ct->status), &ct->status);
	ret = 0;

nfattr_failure:
	return ret;
}

/* Append the remaining lifetime of @ct, in whole seconds, to @skb as a
 * CTA_TIMEOUT attribute.
 *
 * Returns 0 on success, or -1 when NFA_PUT jumps to nfattr_failure. */
static inline int
ctnetlink_dump_timeout(struct sk_buff *skb, const struct ip_conntrack *ct)
{
	/* Signed difference so that an already-due timer yields a negative
	 * value instead of wrapping to a huge unsigned one. */
	long remaining = (long)(ct->timeout.expires - jiffies);
	unsigned long timeout;

	if (remaining > 0)
		timeout = (unsigned long)remaining / HZ;
	else
		timeout = 0;

	NFA_PUT(skb, CTA_TIMEOUT, sizeof(timeout), &timeout);
	return 0;

nfattr_failure:
	return -1;
}

/* Append the protocol number and protocol-private state of @ct to @skb as a
 * CTA_PROTOINFO attribute.
 *
 * Returns 0 on success, or -1 when NFA_PUT jumps to nfattr_failure. */
static inline int
ctnetlink_dump_protoinfo(struct sk_buff *skb, const struct ip_conntrack *ct)
{
	struct cta_proto cp;

	/* The whole struct is copied into the message, so clear it first:
	 * any padding would otherwise carry stale stack contents. */
	memset(&cp, 0, sizeof(cp));
	cp.num_proto = ct->tuplehash[IP_CT_DIR_ORIGINAL].tuple.dst.protonum;
	memcpy(&cp.proto, &ct->proto, sizeof(cp.proto));
	NFA_PUT(skb, CTA_PROTOINFO, sizeof(cp), &cp);
	return 0;

nfattr_failure:
	return -1;
}

/* Append helper information for @ct to @skb as a CTA_HELPINFO attribute.
 * When no helper is attached the attribute is all zeroes (empty name).
 *
 * Returns 0 on success, or -1 when NFA_PUT jumps to nfattr_failure. */
static inline int
ctnetlink_dump_helpinfo(struct sk_buff *skb, const struct ip_conntrack *ct)
{
	struct ip_conntrack_helper *h = ct->helper;
	struct cta_help ch;

	/* Always start from a zeroed struct: it is copied to the message in
	 * full, including any padding between members. */
	memset(&ch, 0, sizeof(struct cta_help));
	if (h != NULL) {
		/* Copy at most size-1 bytes so the zeroed last byte always
		 * terminates the name, even for an over-long helper name. */
		strncpy((char *)&ch.name, h->name, sizeof(ch.name) - 1);
		memcpy(&ch.help, &ct->help, sizeof(ch.help));
	}
	NFA_PUT(skb, CTA_HELPINFO, sizeof(ch), &ch);
	return 0;

nfattr_failure:
	return -1;
}

/* Append the NAT manipulations of @ct to @skb as a CTA_NATINFO attribute.
 * Nothing is emitted when NAT support is compiled out, or when the NAT info
 * is uninitialised or holds no manipulations.
 *
 * Returns 0 on success (including "nothing to emit"), or -1 when NFA_PUT
 * jumps to nfattr_failure. */
static inline int
ctnetlink_dump_natinfo(struct sk_buff *skb, const struct ip_conntrack *ct)
{
#ifdef CONFIG_IP_NF_NAT_NEEDED
	const struct ip_nat_info *info = &ct->nat.info;
	struct cta_nat cn;

	if (!info->initialized || !info->num_manips)
		return 0;

	/* Only num_manips entries are filled in below but the whole struct
	 * is sent, so the unused tail must not contain stack garbage. */
	memset(&cn, 0, sizeof(cn));
	cn.num_manips = info->num_manips;
	memcpy(&cn.manips, &info->manips,
	       info->num_manips * sizeof(struct ip_nat_info_manip));
	NFA_PUT(skb, CTA_NATINFO, sizeof(struct cta_nat), &cn);
	return 0;

nfattr_failure:
	return -1;
#else
	return 0;
#endif
}

/* Placeholder in the dump chain: emits nothing and always returns 0. */
static inline int
ctnetlink_dump_mark(struct sk_buff *skb, const struct ip_conntrack *ct)
{
	return 0;
}

/* Build one complete conntrack message for @ct at the tail of @skb.
 *
 * @pid/@seq go into the netlink header; @event is combined with the
 * ctnetlink subsystem id to form the message type.  NLM_F_MULTI is set
 * only when both @nowait and @pid are non-zero.
 *
 * Returns skb->len on success.  On failure the skb is trimmed back to where
 * this message started and -1 is returned. */
static int
ctnetlink_fill_info(struct sk_buff *skb, u32 pid, u32 seq,
		    int event, int nowait,
		    const struct ip_conntrack *ct)
{
	unsigned char *start = skb->tail;
	struct nfgenmsg *nfmsg;
	struct nlmsghdr *nlh;

	event |= NFNL_SUBSYS_CTNETLINK << 8;
	/* NLMSG_PUT jumps to nlmsg_failure when the skb has no room */
	nlh = NLMSG_PUT(skb, pid, seq, event, sizeof(struct nfgenmsg));
	nfmsg = NLMSG_DATA(nlh);

	if (nowait && pid)
		nlh->nlmsg_flags = NLM_F_MULTI;
	else
		nlh->nlmsg_flags = 0;
	nfmsg->nfgen_family = AF_INET;

	/* Emit every attribute in order; stop at the first failure. */
	if (ctnetlink_dump_tuples(skb, ct) < 0)
		goto nfattr_failure;
	if (ctnetlink_dump_status(skb, ct) < 0)
		goto nfattr_failure;
	if (ctnetlink_dump_timeout(skb, ct) < 0)
		goto nfattr_failure;
	if (ctnetlink_dump_protoinfo(skb, ct) < 0)
		goto nfattr_failure;
	if (ctnetlink_dump_helpinfo(skb, ct) < 0)
		goto nfattr_failure;
	if (ctnetlink_dump_natinfo(skb, ct) < 0)
		goto nfattr_failure;
	if (ctnetlink_dump_mark(skb, ct) < 0)
		goto nfattr_failure;

	nlh->nlmsg_len = skb->tail - start;
	return skb->len;

nlmsg_failure:
nfattr_failure:
	/* drop the partially written message */
	skb_trim(skb, start - skb->data);
	return -1;
}

/* Map the protocol of @ct's original-direction tuple to the multicast group
 * used for its events: TCP, UDP and ICMP have their own group, everything
 * else goes to NFGRP_IPV4_CT_OTHER. */
static inline unsigned int
ctnetlink_get_mcgroups(struct ip_conntrack *ct)
{
	int protonum = ct->tuplehash[IP_CT_DIR_ORIGINAL].tuple.dst.protonum;

	switch (protonum) {
	case IPPROTO_TCP:
		return NFGRP_IPV4_CT_TCP;
	case IPPROTO_UDP:
		return NFGRP_IPV4_CT_UDP;
	case IPPROTO_ICMP:
		return NFGRP_IPV4_CT_ICMP;
	default:
		return NFGRP_IPV4_CT_OTHER;
	}
}

/* Non-zero when bit number 'e' is set in event mask 'm'. */
#define EVENT(m,e) ((m) & (1 << (e)))

/* Notifier callback: turn a conntrack event into a netlink message and send
 * it to the multicast group matching the connection's protocol.
 *
 * @events is a bitmask of IPCT_* event numbers and @ptr is the conntrack the
 * events refer to.  A DESTROY event yields a DELCONNTRACK message; anything
 * else yields NEWCONNTRACK.  For a NEW event all attributes are dumped and
 * NLM_F_CREATE|NLM_F_EXCL is set; otherwise only the attributes whose event
 * bit is set are included.  The tuples are always included.
 *
 * Always returns NOTIFY_DONE; on allocation or fill failure the event is
 * silently dropped. */
static int ctnetlink_conntrack_event(struct notifier_block *this,
                                     unsigned long events, void *ptr)
{
	struct ip_conntrack *ct = (struct ip_conntrack *)ptr;
	struct nfgenmsg *nfmsg;
	struct nlmsghdr *nlh;
	struct sk_buff *skb;
	unsigned char *start;
	unsigned int type;
	int flags = 0;

	/* FIXME: much too big, costs lots of socket buffer space */
	skb = alloc_skb(400 /* NLMSG_GOODSIZE */, GFP_ATOMIC);
	if (!skb)
		return NOTIFY_DONE;

	if (EVENT(events, IPCT_DESTROY)) {
		type = CTNL_MSG_DELCONNTRACK;
	} else if (EVENT(events, IPCT_NEW)) {
		type = CTNL_MSG_NEWCONNTRACK;
		flags = NLM_F_CREATE|NLM_F_EXCL;
		/* a new connection gets every attribute dumped */
		events = ~0UL;
	} else {
		type = CTNL_MSG_NEWCONNTRACK;
	}

	start = skb->tail;

	type |= NFNL_SUBSYS_CTNETLINK << 8;
	/* NLMSG_PUT jumps to nlmsg_failure when the skb has no room */
	nlh = NLMSG_PUT(skb, 0, 0, type, sizeof(struct nfgenmsg));
	nfmsg = NLMSG_DATA(nlh);

	nlh->nlmsg_flags = flags;
	nfmsg->nfgen_family = AF_INET;

	if (ctnetlink_dump_tuples(skb, ct) < 0)
		goto nfattr_failure;

	if (EVENT(events, IPCT_STATUS) &&
	    ctnetlink_dump_status(skb, ct) < 0)
		goto nfattr_failure;

	if (EVENT(events, IPCT_REFRESH) &&
	    ctnetlink_dump_timeout(skb, ct) < 0)
		goto nfattr_failure;

	if (EVENT(events, IPCT_PROTOINFO) &&
	    ctnetlink_dump_protoinfo(skb, ct) < 0)
		goto nfattr_failure;

	if (EVENT(events, IPCT_HELPINFO) &&
	    ctnetlink_dump_helpinfo(skb, ct) < 0)
		goto nfattr_failure;

	if (EVENT(events, IPCT_NATINFO) &&
	    ctnetlink_dump_natinfo(skb, ct) < 0)
		goto nfattr_failure;

	nlh->nlmsg_len = skb->tail - start;
	nfnetlink_send(skb, 0, ctnetlink_get_mcgroups(ct), 0);
	return NOTIFY_DONE;

nlmsg_failure:
nfattr_failure:
	kfree_skb(skb);
	return NOTIFY_DONE;
}

/* Minimum payload size, in bytes, accepted for each attribute type.
 * Indexed by (attribute type - 1), the same way the cda[] arrays in the
 * request handlers are indexed.  Types not listed here are zero, i.e. any
 * payload length passes the size check. */
static const int cta_min[CTA_MAX] = {
	[CTA_ORIG-1]		= sizeof(struct ip_conntrack_tuple),
	[CTA_RPLY-1]		= sizeof(struct ip_conntrack_tuple),
	[CTA_STATUS-1]		= sizeof(unsigned long),
	[CTA_PROTOINFO-1]	= sizeof(struct cta_proto),
	[CTA_HELPINFO-1]	= sizeof(struct cta_help),
	[CTA_NATINFO-1]		= sizeof(struct cta_nat),
	[CTA_TIMEOUT-1]		= sizeof(unsigned long),

	[CTA_EXP_TUPLE-1]	= sizeof(struct ip_conntrack_tuple),
	[CTA_EXP_MASK-1]	= sizeof(struct ip_conntrack_tuple),
	[CTA_EXP_SEQNO-1]	= sizeof(u_int32_t),
	[CTA_EXP_PROTO-1]	= sizeof(struct cta_exp_proto),
	[CTA_EXP_HELP-1]	= sizeof(struct cta_exp_help),
	[CTA_EXP_TIMEOUT-1]	= sizeof(unsigned long)
};

/* Match callback handed to ip_ct_selective_cleanup() by
 * ctnetlink_del_conntrack().  @data is the conntrack to be removed.
 *
 * Returns 1 when @i's original-direction tuple hash entry is bytewise equal
 * to the target's, dropping one reference on the target in that case;
 * returns 0 otherwise. */
static inline int ctnetlink_kill(const struct ip_conntrack *i, void *data)
{
	struct ip_conntrack *target = (struct ip_conntrack *)data;

	if (memcmp(&i->tuplehash[IP_CT_DIR_ORIGINAL],
	           &target->tuplehash[IP_CT_DIR_ORIGINAL],
	           sizeof(struct ip_conntrack_tuple_hash)) != 0)
		return 0;

	ip_conntrack_put(target);
	return 1;
}

/* DELCONNTRACK handler: look up the conntrack named by the CTA_ORIG tuple
 * (or CTA_RPLY when CTA_ORIG is absent) and remove it via
 * ip_ct_selective_cleanup().
 *
 * Returns 0 on success, -EINVAL for malformed or missing attributes, and
 * -ENOENT when no conntrack matches the tuple. */
static int
ctnetlink_del_conntrack(struct sock *ctnl, struct sk_buff *skb,
			struct nlmsghdr *nlh, int *errp)
{
	struct nfattr *cda[CTA_MAX];
	struct nfattr *orig_attr, *reply_attr;
	struct ip_conntrack_tuple_hash *hash;
	struct ip_conntrack_tuple *tuple;

	ct_debug(0, "entered\n");

	if (nfnetlink_check_attributes(ctnl_subsys, nlh, cda) < 0)
		return -EINVAL;

	orig_attr  = cda[CTA_ORIG-1];
	reply_attr = cda[CTA_RPLY-1];

	/* both tuples, when present, must be large enough */
	if (orig_attr && NFA_PAYLOAD(orig_attr) < cta_min[CTA_ORIG-1])
		return -EINVAL;
	if (reply_attr && NFA_PAYLOAD(reply_attr) < cta_min[CTA_RPLY-1])
		return -EINVAL;

	/* prefer the original tuple, fall back to the reply tuple */
	if (orig_attr) {
		tuple = NFA_DATA(orig_attr);
	} else if (reply_attr) {
		tuple = NFA_DATA(reply_attr);
	} else {
		ct_debug(0, "no tuple found in request\n");
		return -EINVAL;
	}

	hash = ip_conntrack_find_get(tuple, NULL);
	if (!hash) {
		ct_debug(0, "tuple not found in conntrack hash:");
		DUMP_TUPLE(tuple);
		return -ENOENT;
	}

	/* ctnetlink_kill() releases the reference taken by the lookup */
	ct_debug(0, "calling selective_cleanup\n");
	ip_ct_selective_cleanup(ctnetlink_kill, hash->ctrack);

	return 0;
}

/* Completion callback passed to netlink_dump_start() for both the conntrack
 * and the expectation dumps.  Only logs a debug line; always returns 0. */
static int ctnetlink_done(struct netlink_callback *cb)
{
	ct_debug(0, "entering\n");
	return 0;
}

/* Dump callback for the conntrack table.  Walks ip_conntrack_ordered_list
 * under the conntrack read lock and appends one message per entry whose id
 * is greater than cb->args[0], which records the id of the last entry
 * written so that a later call resumes after it.  Stops early when an entry
 * cannot be added to @skb.
 *
 * Returns skb->len. */
static int
ctnetlink_dump_table(struct sk_buff *skb, struct netlink_callback *cb)
{
	struct ip_conntrack *entry;
	int rc;

	ct_debug(0, "entered, last=%lu\n", cb->args[0]);

	READ_LOCK(&ip_conntrack_lock);
	list_for_each_entry(entry, &ip_conntrack_ordered_list, olist) {
		/* only entries not yet dumped */
		if (entry->id > cb->args[0]) {
			rc = ctnetlink_fill_info(skb, NETLINK_CB(cb->skb).pid,
						 cb->nlh->nlmsg_seq,
						 CTNL_MSG_NEWCONNTRACK, 1,
						 entry);
			if (rc < 0)
				break;
			cb->args[0] = entry->id;
		}
	}
	READ_UNLOCK(&ip_conntrack_lock);

	ct_debug(0, "leaving, last=%lu\n", cb->args[0]);

	return skb->len;
}

/* GETCONNTRACK handler.
 *
 * With NLM_F_DUMP set, starts a table dump (AF_INET only) and consumes the
 * request from @skb.  Otherwise looks up the conntrack named by the CTA_ORIG
 * tuple (or CTA_RPLY when CTA_ORIG is absent) and unicasts a single
 * NEWCONNTRACK message describing it back to the requester.
 *
 * Returns 0 on success; -EAFNOSUPPORT, -EINVAL, -ENOENT or -ENOMEM for the
 * respective failures, -1 when the reply cannot be built, or the negative
 * result of netlink_unicast().  For a failed dump start the underlying
 * error is stored in *errp. */
static int
ctnetlink_get_conntrack(struct sock *ctnl, struct sk_buff *skb,
			struct nlmsghdr *nlh, int *errp)
{
	struct nfattr *cda[CTA_MAX];
	struct ip_conntrack_tuple_hash *hash;
	struct ip_conntrack_tuple *tuple;
	struct ip_conntrack *ct;
	struct sk_buff *reply;
	int err;

	ct_debug(0, "entered\n");

	if (nlh->nlmsg_flags & NLM_F_DUMP) {
		struct nfgenmsg *msg = NLMSG_DATA(nlh);
		u32 consumed;

		if (msg->nfgen_family != AF_INET)
			return -EAFNOSUPPORT;

		*errp = netlink_dump_start(ctnl, skb, nlh,
					   ctnetlink_dump_table,
					   ctnetlink_done);
		if (*errp != 0)
			return -EINVAL;

		/* skip over the request, but never past the end of skb */
		consumed = NLMSG_ALIGN(nlh->nlmsg_len);
		if (consumed > skb->len)
			consumed = skb->len;
		skb_pull(skb, consumed);
		return 0;
	}

	if (nfnetlink_check_attributes(ctnl_subsys, nlh, cda) < 0)
		return -EINVAL;

	if (cda[CTA_ORIG-1] &&
	    NFA_PAYLOAD(cda[CTA_ORIG-1]) < cta_min[CTA_ORIG-1])
		return -EINVAL;

	if (cda[CTA_RPLY-1] &&
	    NFA_PAYLOAD(cda[CTA_RPLY-1]) < cta_min[CTA_RPLY-1])
		return -EINVAL;

	/* prefer the original tuple, fall back to the reply tuple */
	if (cda[CTA_ORIG-1])
		tuple = NFA_DATA(cda[CTA_ORIG-1]);
	else if (cda[CTA_RPLY-1])
		tuple = NFA_DATA(cda[CTA_RPLY-1]);
	else
		return -EINVAL;

	hash = ip_conntrack_find_get(tuple, NULL);
	if (!hash) {
		ct_debug(0, "tuple not found in conntrack hash:");
		DUMP_TUPLE(tuple);
		return -ENOENT;
	}
	ct = hash->ctrack;

	reply = alloc_skb(NLMSG_GOODSIZE, GFP_ATOMIC);
	if (!reply) {
		ip_conntrack_put(ct);
		return -ENOMEM;
	}
	NETLINK_CB(reply).dst_pid = NETLINK_CB(skb).pid;

	err = ctnetlink_fill_info(reply, NETLINK_CB(skb).pid, nlh->nlmsg_seq,
				  CTNL_MSG_NEWCONNTRACK, 1, ct);
	/* the lookup reference is no longer needed once the reply is built */
	ip_conntrack_put(ct);
	if (err <= 0) {
		kfree_skb(reply);
		return -1;
	}

	err = netlink_unicast(ctnl, reply, NETLINK_CB(skb).pid, MSG_DONTWAIT);
	return err < 0 ? err : 0;
}

/* Replace the status word of @ct with *@status, subject to these rules:
 * the EXPECTED, CONFIRMED and DESTROYED bits may not change at all, and the
 * SEEN_REPLY and ASSURED bits may be set but never cleared.
 *
 * Returns 0 after storing the new status, or -EINVAL when a rule is
 * violated (the conntrack is left untouched). */
static inline int
ctnetlink_change_status(struct ip_conntrack *ct, unsigned long *status)
{
	unsigned long wanted = *status;
	unsigned long changed = ct->status ^ wanted;

	/* bits that are fixed for the lifetime of the conntrack */
	if (changed & (IPS_EXPECTED|IPS_CONFIRMED|IPS_DESTROYED))
		return -EINVAL;

	/* a changed bit that is absent from the request is being cleared */
	if ((changed & IPS_SEEN_REPLY) && !(wanted & IPS_SEEN_REPLY))
		return -EINVAL;

	if ((changed & IPS_ASSURED) && !(wanted & IPS_ASSURED))
		return -EINVAL;

	ct->status = wanted;
	return 0;
}

/* Apply protocol-private state from @cp to @ct.
 *
 * The protocol number in @cp must equal the one in @ct's original tuple.
 * The protocol's optional ctnl_check_private hook may veto the data; its
 * optional ctnl_change hook then applies it.
 *
 * Returns 0 on success, -EINVAL on protocol mismatch or veto. */
static inline int
ctnetlink_change_protoinfo(struct ip_conntrack *ct, struct cta_proto *cp)
{
	int ct_proto = ct->tuplehash[IP_CT_DIR_ORIGINAL].tuple.dst.protonum;
	struct ip_conntrack_protocol *handler;

	if (cp->num_proto != ct_proto)
		return -EINVAL;

	handler = __ip_ct_find_proto(cp->num_proto);

	if (handler->ctnl_check_private) {
		if (handler->ctnl_check_private(&cp->proto) < 0)
			return -EINVAL;
	}

	if (handler->ctnl_change)
		handler->ctnl_change(ct, &cp->proto);

	return 0;
}

/* Attach, detach or update the helper of @ct according to @h.
 *
 * - Empty name, no helper attached: nothing to do.
 * - Empty name, helper attached: remove the conntrack's expectations and
 *   detach the helper.
 * - Non-empty name, no helper attached: only allowed when the conntrack has
 *   no master; the helper is looked up by the reply tuple.
 * - In both non-empty cases the requested name must equal the helper's
 *   name; the helper's optional ctnl_change hook is then invoked.
 *
 * Returns 0 on success, -EINVAL for a conntrack with a master or a name
 * mismatch, -ENOENT when no helper matches the reply tuple. */
static inline int
ctnetlink_change_helpinfo(struct ip_conntrack *ct, struct cta_help *h)
{
	struct ip_conntrack_helper *helper = ct->helper;
	int want_none = (*h->name == '\0');

	if (want_none) {
		if (helper != NULL) {
			ip_conntrack_remove_expectations(ct, 1);
			ct->helper = NULL;
		}
		return 0;
	}

	if (helper == NULL) {
		struct ip_conntrack_tuple *reply;

		if (ct->master)
			return -EINVAL;
		reply = &ct->tuplehash[IP_CT_DIR_REPLY].tuple;
		helper = ip_ct_find_helper(reply);
		if (helper == NULL)
			return -ENOENT;
	}

	/* force termination before comparing the caller-supplied name */
	h->name[CTA_HELP_MAXNAMESZ - 1] = '\0';
	if (strcmp(helper->name, h->name) != 0)
		return -EINVAL;

	ct->helper = helper;
	if (helper->ctnl_change)
		helper->ctnl_change(ct, &h->help);

	return 0;
}

/* Validate a NAT info update @n for @ct.
 *
 * Note that this function only validates: nothing from @n is written into
 * the conntrack here.
 *
 * Rejected with -EINVAL: more than IP_NAT_MAX_MANIPS manipulations, fewer
 * manipulations than an already initialised conntrack holds, or any
 * manipulation with an out-of-range direction, hook number or manip type,
 * or one attached to the FORWARD hook.
 *
 * Returns 0 when @n is acceptable, -EOPNOTSUPP when NAT is compiled out. */
static inline int
ctnetlink_change_natinfo(struct ip_conntrack *ct, struct cta_nat *n)
{
#ifdef CONFIG_IP_NF_NAT_NEEDED
	struct ip_nat_info *info = &ct->nat.info;
	int i;

	if (n->num_manips > IP_NAT_MAX_MANIPS)
		return -EINVAL;

	if (info->initialized && n->num_manips < info->num_manips)
		return -EINVAL;

	for (i = 0; i < n->num_manips; i++) {
		/* IP_CT_DIR_MAX and NF_IP_NUMHOOKS are counts, not the
		 * highest valid value, so they themselves are out of range. */
		if (n->manips[i].direction >= IP_CT_DIR_MAX)
			return -EINVAL;
		if (n->manips[i].hooknum >= NF_IP_NUMHOOKS)
			return -EINVAL;
		if (n->manips[i].hooknum == NF_IP_FORWARD)
			return -EINVAL;
		if (n->manips[i].maniptype > IP_NAT_MANIP_DST)
			return -EINVAL;
	}

	return 0;
#else
	return -EOPNOTSUPP;
#endif
}

/* Re-arm the timeout of @ct to expire *@timeout seconds from now.
 *
 * Returns 0 on success, or -ETIME when del_timer() reports that the timer
 * was not pending, in which case it is not re-armed. */
static inline int
ctnetlink_change_timeout(struct ip_conntrack *ct, unsigned long *timeout)
{
	int ret = -ETIME;

	if (del_timer(&ct->timeout)) {
		ct->timeout.expires = jiffies + *timeout * HZ;
		add_timer(&ct->timeout);
		ret = 0;
	}

	return ret;
}

/* Apply every attribute present in @cda to @ct, in the fixed order status,
 * protocol info, helper info, NAT info, timeout.  Processing stops at the
 * first failure; changes applied before it are not rolled back.
 *
 * Returns 0 on success or the negative error of the failing step. */
static int
ctnetlink_change_conntrack(struct ip_conntrack *ct, struct nfattr *cda[])
{
	int err;

	ct_debug(0, "entered\n");

	if (cda[CTA_STATUS-1]) {
		err = ctnetlink_change_status(ct,
					      NFA_DATA(cda[CTA_STATUS-1]));
		if (err < 0)
			return err;
	}

	if (cda[CTA_PROTOINFO-1]) {
		err = ctnetlink_change_protoinfo(ct,
						 NFA_DATA(cda[CTA_PROTOINFO-1]));
		if (err < 0)
			return err;
	}

	if (cda[CTA_HELPINFO-1]) {
		err = ctnetlink_change_helpinfo(ct,
						NFA_DATA(cda[CTA_HELPINFO-1]));
		if (err < 0)
			return err;
	}

	if (cda[CTA_NATINFO-1]) {
		err = ctnetlink_change_natinfo(ct,
					       NFA_DATA(cda[CTA_NATINFO-1]));
		if (err < 0)
			return err;
	}

	if (cda[CTA_TIMEOUT-1]) {
		err = ctnetlink_change_timeout(ct,
					       NFA_DATA(cda[CTA_TIMEOUT-1]));
		if (err < 0)
			return err;
	}

	ct_debug(0, "all done\n");
	return 0;
}

/* Create a new conntrack from the attributes in @cda.
 *
 * CTA_ORIG, CTA_RPLY, CTA_STATUS, CTA_PROTOINFO and CTA_TIMEOUT are all
 * required, and the status must have IPS_CONFIRMED set.  Remaining
 * attributes (helper, NAT) are applied through ctnetlink_change_conntrack()
 * before the conntrack is placed in the lists and its timer is started.
 *
 * Called from ctnetlink_new_conntrack() with ip_conntrack_lock write-held.
 * Note that the five required entries of @cda are set to NULL on the way.
 *
 * Returns 0 on success, -EINVAL for missing/invalid attributes, -ENOMEM
 * when allocation fails, or the error from ctnetlink_change_conntrack(). */
static int
ctnetlink_create_conntrack(struct nfattr *cda[])
{
	struct ip_conntrack *ct;
	struct ip_conntrack_tuple *otuple, *rtuple, t;
	struct ip_conntrack_protocol *icp;
	struct cta_proto *proto;
	unsigned long *status;
	unsigned long *timeout;
	int err;

	ct_debug(0, "entered\n");

	if (!(cda[CTA_ORIG-1] && cda[CTA_RPLY-1] && cda[CTA_STATUS-1] &&
	      cda[CTA_PROTOINFO-1] && cda[CTA_TIMEOUT-1])) {
		ct_debug(0, "required attribute(s) missing\n");
		return -EINVAL;
	}

	otuple  = NFA_DATA(cda[CTA_ORIG-1]);
	rtuple  = NFA_DATA(cda[CTA_RPLY-1]);
	timeout = NFA_DATA(cda[CTA_TIMEOUT-1]);

	status  = NFA_DATA(cda[CTA_STATUS-1]);
	if (!(*status & IPS_CONFIRMED))
		return -EINVAL;	/* cannot create unconfirmed connections */

	proto = NFA_DATA(cda[CTA_PROTOINFO-1]);
	icp   = __ip_ct_find_proto(proto->num_proto);

	/* The reply tuple is checked against the inverted original tuple, but
	 * the result is deliberately ignored (empty statement below). */
	if (!invert_tuple(&t, otuple, icp) || !ip_ct_tuple_equal(&t, rtuple))
		; // FIXME: nat changes reply tuples // return -EINVAL;

	/* optional per-protocol validation hooks */
	if (icp->ctnl_check_tuples
	    && icp->ctnl_check_tuples(otuple, rtuple) < 0)
		return -EINVAL;
	
	if (icp->ctnl_check_private
	    && icp->ctnl_check_private(&proto->proto) < 0)
		return -EINVAL;

	ct = ip_conntrack_alloc(otuple, rtuple);
	if (ct == NULL)
		return -ENOMEM;

	ct->status = *status;
	ct->timeout.expires = jiffies + *timeout * HZ;

	if (icp->ctnl_change)
		icp->ctnl_change(ct, &proto->proto);

	/* These five were consumed above; clear them so that
	 * ctnetlink_change_conntrack() only applies the remaining ones. */
	cda[CTA_ORIG-1] = cda[CTA_RPLY-1] = cda[CTA_PROTOINFO-1] = 
		cda[CTA_STATUS-1] = cda[CTA_TIMEOUT-1] = NULL;

	err = ctnetlink_change_conntrack(ct, cda);
	if (err < 0) {
		ip_conntrack_free(ct);
		return err;
	}

	/* publish the conntrack, then start its timeout */
	ip_conntrack_place_in_lists(ct);
	add_timer(&ct->timeout);

	ct_debug(0, "all done\n");
	return 0;
}

/* NEWCONNTRACK handler: create a conntrack or update an existing one.
 *
 * The conntrack is looked up by CTA_ORIG first, then by CTA_RPLY, under the
 * conntrack write lock.  When none exists it is created if NLM_F_CREATE is
 * set (-ENOENT otherwise).  When one exists it is updated unless NLM_F_EXCL
 * is set (-EEXIST in that case).
 *
 * Returns 0 on success or a negative errno; -EINVAL for malformed,
 * undersized or missing attributes. */
static int
ctnetlink_new_conntrack(struct sock *ctnl, struct sk_buff *skb,
			struct nlmsghdr *nlh, int *errp)
{
	struct nfattr *cda[CTA_MAX];
	struct ip_conntrack_tuple_hash *hash = NULL;
	struct ip_conntrack_tuple *orig = NULL;
	struct ip_conntrack_tuple *reply = NULL;
	int idx;
	int err;

	ct_debug(0, "entered\n");

	if (nfnetlink_check_attributes(ctnl_subsys, nlh, cda) < 0)
		return -EINVAL;

	/* every attribute present must meet its minimum size */
	for (idx = 0; idx < CTA_MAX; idx++) {
		if (cda[idx] && NFA_PAYLOAD(cda[idx]) < cta_min[idx])
			return -EINVAL;
	}

	ct_debug(0, "all attribute sizes ok\n");

	if (cda[CTA_ORIG-1])
		orig = NFA_DATA(cda[CTA_ORIG-1]);
	if (cda[CTA_RPLY-1])
		reply = NFA_DATA(cda[CTA_RPLY-1]);

	if (!orig && !reply) {
		ct_debug(0, "no tuple found in request\n");
		return -EINVAL;
	}

	WRITE_LOCK(&ip_conntrack_lock);
	if (orig)
		hash = __ip_conntrack_find_get(orig, NULL);
	if (!hash && reply)
		hash = __ip_conntrack_find_get(reply, NULL);

	if (!hash) {
		ct_debug(0, "no such conntrack, create new\n");
		if (nlh->nlmsg_flags & NLM_F_CREATE)
			err = ctnetlink_create_conntrack(cda);
		else
			err = -ENOENT;
		goto out_unlock;
	}

	ct_debug(0, "conntrack found, change\n");
	if (nlh->nlmsg_flags & NLM_F_EXCL)
		err = -EEXIST;
	else
		err = ctnetlink_change_conntrack(hash->ctrack, cda);

	/* release the reference taken by the lookup */
	ip_conntrack_put(hash->ctrack);
out_unlock:
	WRITE_UNLOCK(&ip_conntrack_lock);
	return err;
}

/* EXPECT */

/* Append the tuple and mask of expectation @exp to @skb as the
 * CTA_EXP_TUPLE and CTA_EXP_MASK attributes.
 *
 * Returns 0 on success, or -1 when NFA_PUT jumps to nfattr_failure. */
static inline int
ctnetlink_exp_dump_tuples(struct sk_buff *skb,
                          const struct ip_conntrack_expect *exp)
{
	int ret = -1;

	NFA_PUT(skb, CTA_EXP_TUPLE, sizeof(struct ip_conntrack_tuple),
		&exp->tuple);
	NFA_PUT(skb, CTA_EXP_MASK, sizeof(struct ip_conntrack_tuple),
		&exp->mask);
	ret = 0;

nfattr_failure:
	return ret;
}

/* Append the sequence number of expectation @exp to @skb as a
 * CTA_EXP_SEQNO attribute.
 *
 * Returns 0 on success, or -1 when NFA_PUT jumps to nfattr_failure. */
static inline int
ctnetlink_exp_dump_seqno(struct sk_buff *skb,
                         const struct ip_conntrack_expect *exp)
{
	int ret = -1;

	NFA_PUT(skb, CTA_EXP_SEQNO, sizeof(u_int32_t), &exp->seq);
	ret = 0;

nfattr_failure:
	return ret;
}

/* Placeholder in the expectation dump chain: emits nothing and always
 * returns 0. */
static inline int
ctnetlink_exp_dump_proto(struct sk_buff *skb,
                         const struct ip_conntrack_expect *exp)
{
	return 0;
}

/* Append the helper-private data of expectation @exp to @skb as a
 * CTA_EXP_HELP attribute.
 *
 * The payload is a struct cta_exp_help, the same type whose size cta_min[]
 * requires for this attribute on input.
 *
 * Returns 0 on success, or -1 when NFA_PUT jumps to nfattr_failure. */
static inline int
ctnetlink_exp_dump_help(struct sk_buff *skb,
                        const struct ip_conntrack_expect *exp)
{
	struct cta_exp_help ch;

	/* zero first so no stack garbage ends up in the message */
	memset(&ch, 0, sizeof(ch));
	memcpy(&ch.help, &exp->help, sizeof(ch.help));
	/* send the wrapper that was just filled in, not the raw union */
	NFA_PUT(skb, CTA_EXP_HELP, sizeof(ch), &ch);
	return 0;

nfattr_failure:
	return -1;
}

/* Build one complete expectation message for @exp at the tail of @skb.
 *
 * @pid/@seq go into the netlink header; @event is combined with the
 * ctnetlink subsystem id to form the message type.  NLM_F_MULTI is set
 * only when both @nowait and @pid are non-zero.
 *
 * Returns skb->len on success.  On failure the skb is trimmed back to where
 * this message started and -1 is returned. */
static int
ctnetlink_exp_fill_info(struct sk_buff *skb, u32 pid, u32 seq,
		    int event,
		    int nowait,
		    const struct ip_conntrack_expect *exp)
{
	unsigned char *start = skb->tail;
	struct nfgenmsg *nfmsg;
	struct nlmsghdr *nlh;

	event |= NFNL_SUBSYS_CTNETLINK << 8;
	/* NLMSG_PUT jumps to nlmsg_failure when the skb has no room */
	nlh = NLMSG_PUT(skb, pid, seq, event, sizeof(struct nfgenmsg));
	nfmsg = NLMSG_DATA(nlh);

	if (nowait && pid)
		nlh->nlmsg_flags = NLM_F_MULTI;
	else
		nlh->nlmsg_flags = 0;
	nfmsg->nfgen_family = AF_INET;

	/* Emit every attribute in order; stop at the first failure. */
	if (ctnetlink_exp_dump_tuples(skb, exp) < 0)
		goto nfattr_failure;
	if (ctnetlink_exp_dump_seqno(skb, exp) < 0)
		goto nfattr_failure;
	if (ctnetlink_exp_dump_proto(skb, exp) < 0)
		goto nfattr_failure;
	if (ctnetlink_exp_dump_help(skb, exp) < 0)
		goto nfattr_failure;

	nlh->nlmsg_len = skb->tail - start;
	return skb->len;

nlmsg_failure:
nfattr_failure:
	/* drop the partially written message */
	skb_trim(skb, start - skb->data);
	return -1;
}

/* Allocate an skb and fill it with a NEWEXPECT message describing @exp
 * (pid and sequence number 0).
 *
 * Returns the skb, or NULL when allocation or message construction fails. */
static inline struct sk_buff *
ctnetlink_exp_event_build_msg(const struct ip_conntrack_expect *exp)
{
	struct sk_buff *msg = alloc_skb(NLMSG_GOODSIZE, GFP_ATOMIC);

	if (!msg)
		return NULL;

	if (ctnetlink_exp_fill_info(msg, 0, 0, CTNL_MSG_NEWEXPECT, 1,
				    exp) <= 0) {
		kfree_skb(msg);
		return NULL;
	}

	return msg;
}

/* Announce a newly created expectation @exp: build a NEWEXPECT message and
 * send it to the multicast group matching the expectation's protocol (TCP,
 * UDP, ICMP, or "other").  Does nothing when the message cannot be built.
 *
 * The skb is handed to nfnetlink_send() on every path, so there is no
 * free-after-send here. */
static void
ctnetlink_exp_create(struct ip_conntrack_expect *exp)
{
	u16 proto = exp->tuple.dst.protonum;
	unsigned int group;
	struct sk_buff *skb;

	skb = ctnetlink_exp_event_build_msg(exp);
	if (!skb)
		return;

	switch (proto) {
	case IPPROTO_TCP:
		group = NFGRP_IPV4_CT_TCP;
		break;
	case IPPROTO_UDP:
		group = NFGRP_IPV4_CT_UDP;
		break;
	case IPPROTO_ICMP:
		group = NFGRP_IPV4_CT_ICMP;
		break;
	default:
		group = NFGRP_IPV4_CT_OTHER;
		break;
	}

	nfnetlink_send(skb, 0, group, 0);
}

/* DELEXPECT handler: find the expectation matching the CTA_ORIG tuple (or
 * CTA_RPLY when CTA_ORIG is absent), unlink it and drop the lookup
 * reference.
 *
 * Returns 0 on success, -EINVAL for malformed or missing attributes, and
 * -ENOENT when no expectation matches. */
static int
ctnetlink_del_expect(struct sock *ctnl, struct sk_buff *skb,
		     struct nlmsghdr *nlh, int *errp)
{
	struct nfattr *cda[CTA_MAX];
	struct nfattr *orig_attr, *reply_attr;
	struct ip_conntrack_tuple *tuple;
	struct ip_conntrack_expect *found;

	if (nfnetlink_check_attributes(ctnl_subsys, nlh, cda) < 0)
		return -EINVAL;

	orig_attr  = cda[CTA_ORIG-1];
	reply_attr = cda[CTA_RPLY-1];

	if (orig_attr && NFA_PAYLOAD(orig_attr) < cta_min[CTA_ORIG-1])
		return -EINVAL;
	if (reply_attr && NFA_PAYLOAD(reply_attr) < cta_min[CTA_RPLY-1])
		return -EINVAL;

	/* prefer the original tuple, fall back to the reply tuple */
	if (orig_attr)
		tuple = NFA_DATA(orig_attr);
	else if (reply_attr)
		tuple = NFA_DATA(reply_attr);
	else
		return -EINVAL;

	/* the lookup takes a reference: usage count becomes 2 */
	found = ip_conntrack_expect_find_get(tuple);
	if (!found)
		return -ENOENT;

	/* unlinking from the list drops it to 1 ... */
	ip_conntrack_unexpect_related(found);
	/* ... and releasing the lookup reference brings it to 0 */
	ip_conntrack_expect_put(found);

	return 0;
}

/* Per-expectation callback used by ctnetlink_exp_dump_table() through
 * LIST_FIND: append a NEWEXPECT message for @exp to the dump @skb.
 *
 * Returns 0 when the message was added and -1 when it was not.
 *
 * The skb belongs to the dump in progress and is still read by the caller
 * afterwards, so it must not be freed here; on failure
 * ctnetlink_exp_fill_info() has already trimmed the partial message. */
static int
ctnetlink_exp_dump_build_msg(const struct ip_conntrack_expect *exp,
			 struct sk_buff *skb, u32 pid, u32 seq)
{
	int err;

	err = ctnetlink_exp_fill_info(skb, pid, seq, CTNL_MSG_NEWEXPECT, 1,
				      exp);
	if (err <= 0)
		return -1;

	return 0;
}

/* Dump callback for the expectation list.  On the first invocation
 * (cb->args[0] == 0) it runs ctnetlink_exp_dump_build_msg() over
 * ip_conntrack_expect_list under the conntrack read lock and then sets
 * cb->args[0] to 1; later invocations add nothing.
 *
 * Returns skb->len. */
static int
ctnetlink_exp_dump_table(struct sk_buff *skb, struct netlink_callback *cb)
{
	ct_debug(0, "entered\n");

	if (!cb->args[0]) {
		READ_LOCK(&ip_conntrack_lock);
		LIST_FIND(&ip_conntrack_expect_list,
			  ctnetlink_exp_dump_build_msg,
			  struct ip_conntrack_expect *, skb,
			  NETLINK_CB(cb->skb).pid, cb->nlh->nlmsg_seq);
		READ_UNLOCK(&ip_conntrack_lock);

		/* mark the dump as done for subsequent calls */
		cb->args[0] = 1;
	}

	ct_debug(0, "returning\n");

	return skb->len;
}


/* GETEXPECT handler.
 *
 * With NLM_F_DUMP set, starts an expectation dump (AF_INET only) and
 * consumes the request from @skb.  Otherwise looks up the expectation
 * matching the CTA_ORIG tuple (or CTA_RPLY when CTA_ORIG is absent) and
 * unicasts a single NEWEXPECT message describing it back to the requester.
 *
 * Returns 0 on success; -EAFNOSUPPORT, -EINVAL, -ENOENT or -ENOMEM for the
 * respective failures, -1 when the reply cannot be built, or the negative
 * result of netlink_unicast().  For a failed dump start the underlying
 * error is stored in *errp. */
static int
ctnetlink_get_expect(struct sock *ctnl, struct sk_buff *skb,
		     struct nlmsghdr *nlh, int *errp)
{
	struct ip_conntrack_expect *exp;
	struct ip_conntrack_tuple *tuple;
	struct nfattr *cda[CTA_MAX];
	struct sk_buff *skb2 = NULL;
	int err;

	ct_debug(0, "entered\n");

	if (nlh->nlmsg_flags & NLM_F_DUMP) {
		struct nfgenmsg *msg = NLMSG_DATA(nlh);
		u32 rlen;

		if (msg->nfgen_family != AF_INET)
			return -EAFNOSUPPORT;

		ct_debug(0, "starting dump\n");
		if ((*errp = netlink_dump_start(ctnl, skb, nlh,
						ctnetlink_exp_dump_table,
						ctnetlink_done)) != 0)
			return -EINVAL;

		/* skip over the request, but never past the end of skb */
		rlen = NLMSG_ALIGN(nlh->nlmsg_len);
		if (rlen > skb->len)
			rlen = skb->len;
		skb_pull(skb, rlen);
		return 0;
	}

	if (nfnetlink_check_attributes(ctnl_subsys, nlh, cda) < 0)
		return -EINVAL;

	if (cda[CTA_ORIG-1]
	    && NFA_PAYLOAD(cda[CTA_ORIG-1]) < sizeof(struct ip_conntrack_tuple))
		return -EINVAL;

	if (cda[CTA_RPLY-1]
	    && NFA_PAYLOAD(cda[CTA_RPLY-1]) < sizeof(struct ip_conntrack_tuple))
		return -EINVAL;

	/* prefer the original tuple, fall back to the reply tuple */
	if (cda[CTA_ORIG-1])
		tuple = NFA_DATA(cda[CTA_ORIG-1]);
	else if (cda[CTA_RPLY-1])
		tuple = NFA_DATA(cda[CTA_RPLY-1]);
	else
		return -EINVAL;

	/* takes a reference that must be released on every path below */
	exp = ip_conntrack_expect_find_get(tuple);
	if (!exp)
		return -ENOENT;

	skb2 = alloc_skb(NLMSG_GOODSIZE, GFP_ATOMIC);
	if (!skb2) {
		ip_conntrack_expect_put(exp);
		return -ENOMEM;
	}
	NETLINK_CB(skb2).dst_pid = NETLINK_CB(skb).pid;

	err = ctnetlink_exp_fill_info(skb2, NETLINK_CB(skb).pid,
				      nlh->nlmsg_seq, CTNL_MSG_NEWEXPECT,
				      1, exp);
	/* the expectation is no longer needed once the reply is built */
	ip_conntrack_expect_put(exp);
	if (err <= 0)
		goto nlmsg_failure;

	err = netlink_unicast(ctnl, skb2, NETLINK_CB(skb).pid, MSG_DONTWAIT);
	if (err < 0)
		return err;
	return 0;

nlmsg_failure:
	kfree_skb(skb2);
	return -1;
}

/* Updating an existing expectation is not implemented: always returns
 * -EOPNOTSUPP and touches neither argument. */
static int
ctnetlink_change_expect(struct ip_conntrack_expect *x, struct nfattr *cda[])
{

	return -EOPNOTSUPP;
}

/* Create a new expectation from the attributes in @cda.
 *
 * CTA_EXP_TUPLE and CTA_EXP_MASK describe the expectation; the master
 * conntrack is looked up by CTA_ORIG first, then by CTA_RPLY (at least one
 * of the two must be present).  The timeout comes from CTA_EXP_TIMEOUT or,
 * failing that, from the master's helper.  If the helper provides a
 * ctnl_new_expect hook it is given the optional CTA_EXP_PROTO and
 * CTA_EXP_HELP payloads (NULL when the attribute is absent).
 *
 * Called from ctnetlink_new_expect() with ip_conntrack_lock write-held.
 *
 * Returns 0 on success, -EINVAL for missing attributes or no usable
 * timeout, -ENOENT when the master conntrack is not found, or the error
 * from __ip_conntrack_expect_related(). */
static int
ctnetlink_create_expect(struct nfattr *cda[])
{
	struct ip_conntrack_tuple_hash *h = NULL;
	struct ip_conntrack_expect exp, *new;
	struct ip_conntrack_helper *helper;
	unsigned long timeout;
	int err;

	ct_debug(0, "entered\n");

	if (!(cda[CTA_ORIG-1] || cda[CTA_RPLY-1])) {
		ct_debug(0, "required attributes missing\n");
		return -EINVAL;
	}
	if (!cda[CTA_EXP_TUPLE-1] || !cda[CTA_EXP_MASK-1]) {
		ct_debug(0, "required attributes missing\n");
		return -EINVAL;
	}

	/* start from a fully defined template (seq stays 0 if not given) */
	memset(&exp, 0, sizeof(exp));
	memcpy(&exp.tuple, NFA_DATA(cda[CTA_EXP_TUPLE-1]),
	       sizeof(struct ip_conntrack_tuple));
	memcpy(&exp.mask, NFA_DATA(cda[CTA_EXP_MASK-1]),
	       sizeof(struct ip_conntrack_tuple));

	exp.expectfn = NULL;

	if (cda[CTA_EXP_SEQNO-1])
		exp.seq = *(u_int32_t *)NFA_DATA(cda[CTA_EXP_SEQNO-1]);

	/* Only dereference the tuple attributes that are actually present. */
	if (cda[CTA_ORIG-1])
		h = __ip_conntrack_find_get(NFA_DATA(cda[CTA_ORIG-1]), NULL);
	if (h == NULL && cda[CTA_RPLY-1])
		h = __ip_conntrack_find_get(NFA_DATA(cda[CTA_RPLY-1]), NULL);
	if (h == NULL)
		return -ENOENT;

	helper = h->ctrack->helper;

	if (cda[CTA_EXP_TIMEOUT-1])
		timeout = *(unsigned long *)NFA_DATA(cda[CTA_EXP_TIMEOUT-1]);
	else if (helper && helper->timeout)
		timeout = helper->timeout;
	else {
		err = -EINVAL;
		goto out_put;
	}

	if (helper && helper->ctnl_new_expect) {
		struct cta_exp_proto *cp = NULL;
		struct cta_exp_help *ch = NULL;

		if (cda[CTA_EXP_PROTO-1])
			cp = NFA_DATA(cda[CTA_EXP_PROTO-1]);
		if (cda[CTA_EXP_HELP-1])
			ch = NFA_DATA(cda[CTA_EXP_HELP-1]);

		/* never form a member address from a NULL pointer */
		helper->ctnl_new_expect(&exp,
					cp ? &cp->proto : NULL,
					ch ? &ch->help : NULL);
	}

	err = __ip_conntrack_expect_related(h->ctrack, &exp, &new);
	if (err < 0)
		goto out_put;

	new->timeout.expires = jiffies + timeout * HZ;
	add_timer(&new->timeout);
	err = 0;

out_put:
	/* release the reference taken by the lookup, as
	 * ctnetlink_new_conntrack() does under the same lock */
	ip_conntrack_put(h->ctrack);
	return err;
}

/* NEWEXPECT handler: create an expectation or update an existing one.
 *
 * CTA_EXP_TUPLE and CTA_EXP_MASK are required.  The expectation is looked
 * up by tuple and mask under the conntrack write lock.  When none exists it
 * is created if NLM_F_CREATE is set (-ENOENT otherwise).  When one exists
 * it is updated unless NLM_F_EXCL is set (-EEXIST in that case).
 *
 * Returns 0 on success or a negative errno; -EINVAL for malformed,
 * undersized or missing attributes. */
static int
ctnetlink_new_expect(struct sock *ctnl, struct sk_buff *skb,
		     struct nlmsghdr *nlh, int *errp)
{
	struct nfattr *cda[CTA_MAX];
	struct ip_conntrack_tuple *tuple, *mask;
	struct ip_conntrack_expect *existing;
	int idx;
	int err;

	if (nfnetlink_check_attributes(ctnl_subsys, nlh, cda) < 0)
		return -EINVAL;

	/* every attribute present must meet its minimum size */
	for (idx = 0; idx < CTA_MAX; idx++) {
		if (cda[idx] && NFA_PAYLOAD(cda[idx]) < cta_min[idx])
			return -EINVAL;
	}

	if (!(cda[CTA_EXP_TUPLE-1] && cda[CTA_EXP_MASK-1]))
		return -EINVAL;

	tuple = NFA_DATA(cda[CTA_EXP_TUPLE-1]);
	mask  = NFA_DATA(cda[CTA_EXP_MASK-1]);

	WRITE_LOCK(&ip_conntrack_lock);
	existing = __ip_ct_expect_find_tm(tuple, mask);

	if (existing != NULL) {
		if (nlh->nlmsg_flags & NLM_F_EXCL)
			err = -EEXIST;
		else
			err = ctnetlink_change_expect(existing, cda);
	} else if (nlh->nlmsg_flags & NLM_F_CREATE) {
		err = ctnetlink_create_expect(cda);
	} else {
		err = -ENOENT;
	}

	WRITE_UNLOCK(&ip_conntrack_lock);
	return err;
}

/* struct conntrack_expect stuff */

/* Notifier block registered with ip_conntrack_notify_register() in
 * ctnetlink_init(); its callback is ctnetlink_conntrack_event().  The
 * initializer is positional; the remaining two members are NULL and 0. */
static struct notifier_block ctnl_notifier = {
	ctnetlink_conntrack_event,
	NULL,
	0
};

/* Module exit: undo ctnetlink_init() in reverse order -- unregister the
 * conntrack event notifier, unregister the nfnetlink subsystem, then free
 * the subsystem descriptor. */
static void __exit ctnetlink_exit(void)
{
	printk("ctnetlink: unregistering with nfnetlink.\n");

	/* the expectation notifier (ctnl_exp_notify) is not registered yet,
	 * so there is nothing to unregister for it */
	ip_conntrack_notify_unregister(&ctnl_notifier);

	nfnetlink_subsys_unregister(ctnl_subsys);
	kfree(ctnl_subsys);
}

/* Module init: allocate and populate the nfnetlink subsystem descriptor,
 * register it with nfnetlink, then register the conntrack event notifier.
 *
 * The NEW/DEL handlers for conntracks and expectations require
 * CAP_NET_ADMIN; the GET handlers require no capability.
 *
 * Returns 0 on success or a negative errno; on failure everything set up so
 * far is torn down again. */
static int __init ctnetlink_init(void)
{
	int ret;

	ctnl_subsys = nfnetlink_subsys_alloc(CTNL_MSG_COUNT);
	if (!ctnl_subsys) {
		ret = -ENOMEM;
		goto err_out; 
	}

	ctnl_subsys->name = "conntrack";
	ctnl_subsys->subsys_id = NFNL_SUBSYS_CTNETLINK;
	ctnl_subsys->cb_count = CTNL_MSG_COUNT;
	ctnl_subsys->attr_count = CTA_MAX;
	ctnl_subsys->cb[CTNL_MSG_NEWCONNTRACK].call = ctnetlink_new_conntrack;
	ctnl_subsys->cb[CTNL_MSG_NEWCONNTRACK].cap_required = CAP_NET_ADMIN;
	ctnl_subsys->cb[CTNL_MSG_DELCONNTRACK].call = ctnetlink_del_conntrack;
	ctnl_subsys->cb[CTNL_MSG_DELCONNTRACK].cap_required = CAP_NET_ADMIN;
	ctnl_subsys->cb[CTNL_MSG_GETCONNTRACK].call = ctnetlink_get_conntrack;
	ctnl_subsys->cb[CTNL_MSG_GETCONNTRACK].cap_required = 0;
	ctnl_subsys->cb[CTNL_MSG_NEWEXPECT].call = ctnetlink_new_expect;
	ctnl_subsys->cb[CTNL_MSG_NEWEXPECT].cap_required = CAP_NET_ADMIN;
	ctnl_subsys->cb[CTNL_MSG_DELEXPECT].call = ctnetlink_del_expect;
	ctnl_subsys->cb[CTNL_MSG_DELEXPECT].cap_required = CAP_NET_ADMIN;
	ctnl_subsys->cb[CTNL_MSG_GETEXPECT].call = ctnetlink_get_expect;
	ctnl_subsys->cb[CTNL_MSG_GETEXPECT].cap_required = 0;

	printk("ctnetlink v%s: registering with nfnetlink.\n", ctversion);
	/* Assign first, compare second: ret must hold the real error code,
	 * not the boolean result of the comparison. */
	ret = nfnetlink_subsys_register(ctnl_subsys);
	if (ret < 0) {
		printk("ctnetlink_init: cannot register with nfnetlink.\n");
		goto err_free_subsys;
	}

	ret = ip_conntrack_notify_register(&ctnl_notifier);
	if (ret < 0) {
		printk("ctnetlink_init: cannot register notifier.\n");
		goto err_unreg_subsys;
	}

#if 0
	if ((ret = ip_conntrack_notify_register(&ctnl_exp_notify)) < 0) {
		printk("ctnetlink_init: cannot register exp notifier\n");
		goto err_unreg_notify;
	}
#endif


	return 0;
	
#if 0
err_unreg_notify:
	ip_conntrack_notify_unregister(&ctnl_notify);
#endif 
err_unreg_subsys:
	nfnetlink_subsys_unregister(ctnl_subsys);
err_free_subsys:
	kfree(ctnl_subsys);
err_out:
	return ret;
}

module_init(ctnetlink_init);
module_exit(ctnetlink_exit);
