/*
 * Copyright (C) 2008-2009 Martin Willi
 * Hochschule fuer Technik Rapperswil
 *
 * This program is free software; you can redistribute it and/or modify it
 * under the terms of the GNU General Public License as published by the
 * Free Software Foundation; either version 2 of the License, or (at your
 * option) any later version.  See <http://www.fsf.org/copyleft/gpl.txt>.
 *
 * This program is distributed in the hope that it will be useful, but
 * WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY
 * or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
 * for more details.
 */

#include "agent_private_key.h"

#include <stdlib.h>
#include <string.h>
#include <unistd.h>
#include <sys/types.h>
#include <sys/socket.h>
#include <sys/un.h>
#include <arpa/inet.h>
#include <errno.h>

#include <library.h>
#include <chunk.h>
#include <debug.h>

/* Fallback definition for platforms whose headers do not provide
 * UNIX_PATH_MAX. */
#ifndef UNIX_PATH_MAX
#define UNIX_PATH_MAX 108
#endif /* UNIX_PATH_MAX */

/* Forward typedefs for the private object and the message type enum that
 * are defined further below. */
typedef struct private_agent_private_key_t private_agent_private_key_t;
typedef enum agent_msg_type_t agent_msg_type_t;

/**
 * Private data of an agent_private_key_t object.
 */
struct private_agent_private_key_t {
	/**
	 * Public interface of this private key.
	 *
	 * Kept as the first member: the method implementations in this file are
	 * installed by casting between the public and the private object type.
	 */
	agent_private_key_t public;

	/**
	 * ssh-agent unix socket connection, closed in destroy()
	 */
	int socket;

	/**
	 * key identity blob in ssh format
	 *
	 * Parsed in this file as a sequence of length-prefixed strings: key type,
	 * public exponent, modulus. After a successful read_key() this is a heap
	 * copy that destroy() frees.
	 */
	chunk_t key;

	/**
	 * keysize in bytes, derived from the modulus length in read_key()
	 */
	size_t key_size;

	/**
	 * reference count
	 */
	refcount_t ref;
};

/**
 * Message types for ssh-agent protocol
 *
 * Only the ID and SIGN request/response pairs are used in this file:
 * read_key() sends SSH_AGENT_ID_REQUEST and expects SSH_AGENT_ID_RESPONSE,
 * sign() sends SSH_AGENT_SIGN_REQUEST and expects SSH_AGENT_SIGN_RESPONSE.
 */
enum agent_msg_type_t {
	SSH_AGENT_FAILURE = 5,
	SSH_AGENT_SUCCESS =	6,
	SSH_AGENT_ID_REQUEST = 11,
	SSH_AGENT_ID_RESPONSE = 12,
	SSH_AGENT_SIGN_REQUEST = 13,
	SSH_AGENT_SIGN_RESPONSE = 14,
};

/**
 * Consume a single byte from the front of a blob.
 *
 * On success the byte is returned and the blob is advanced past it. If the
 * blob is empty, 0 is returned and the blob is left untouched.
 */
static u_char read_byte(chunk_t *blob)
{
	u_char first;

	if (blob->len >= sizeof(u_char))
	{
		first = blob->ptr[0];
		*blob = chunk_skip(*blob, sizeof(u_char));
		return first;
	}
	return 0;
}

/**
 * Consume a big-endian (network order) u_int32_t from the front of a blob.
 *
 * On success the value is returned in host byte order and the blob is
 * advanced by four bytes. If fewer than four bytes are available, 0 is
 * returned and the blob is left untouched.
 */
static u_int32_t read_uint32(chunk_t *blob)
{
	u_int32_t val;

	if (blob->len < sizeof(u_int32_t))
	{
		return 0;
	}
	/* the blob may start at any offset of a receive buffer, so copy the bytes
	 * out instead of dereferencing a possibly misaligned u_int32_t pointer */
	memcpy(&val, blob->ptr, sizeof(val));
	*blob = chunk_skip(*blob, sizeof(u_int32_t));
	return ntohl(val);
}

/**
 * Consume an ssh-agent "string" (32-bit length followed by that many bytes)
 * from the front of a blob.
 *
 * The returned chunk points into the blob's memory, nothing is copied. If
 * the announced length exceeds the remaining data, chunk_empty is returned;
 * the length field itself has been consumed by then. A valid zero-length
 * string also yields an empty chunk, so callers can not tell the two apart
 * from the return value alone.
 */
static chunk_t read_string(chunk_t *blob)
{
	u_int32_t len;
	chunk_t str;

	/* keep the wire length unsigned: a value above INT_MAX must not turn into
	 * a negative number before it is compared against the remaining size */
	len = read_uint32(blob);
	if (len > blob->len)
	{
		return chunk_empty;
	}
	str = chunk_create(blob->ptr, len);
	*blob = chunk_skip(*blob, len);
	return str;
}

/**
 * Open a socket connection to the ssh-agent.
 *
 * @param path		filesystem path of the agent's unix domain socket
 * @return			connected socket descriptor, -1 on failure
 *
 * A path that does not fit into sockaddr_un.sun_path is rejected instead of
 * being truncated, as a truncated path would name a different socket.
 */
static int open_connection(char *path)
{
	struct sockaddr_un addr;
	size_t path_len;
	int s;

	/* bound by the real size of sun_path, which differs between platforms */
	path_len = strlen(path);
	if (path_len >= sizeof(addr.sun_path))
	{
		DBG1(DBG_LIB, "ssh-agent socket path %s is too long", path);
		return -1;
	}

	s = socket(AF_UNIX, SOCK_STREAM, 0);
	if (s == -1)
	{
		DBG1(DBG_LIB, "opening ssh-agent socket %s failed: %s", path,
			 strerror(errno));
		return -1;
	}

	/* zeroing the address also NUL-terminates sun_path for SUN_LEN() */
	memset(&addr, 0, sizeof(addr));
	addr.sun_family = AF_UNIX;
	memcpy(addr.sun_path, path, path_len);

	if (connect(s, (struct sockaddr*)&addr, SUN_LEN(&addr)) != 0)
	{
		DBG1(DBG_LIB, "connecting to ssh-agent socket failed: %s",
			 strerror(errno));
		close(s);
		return -1;
	}
	return s;
}

/**
 * Get the first usable key from the agent.
 *
 * Sends an identity request and walks the returned identities, each of which
 * is a key blob followed by a comment string. A key is usable if it is an
 * "ssh-rsa" key with a modulus longer than 512 bits and, if pubkey is given,
 * belongs to that public key.
 *
 * @param this		object to store the selected key blob and key size in
 * @param pubkey	public key the private key must belong to, or NULL
 * @return			TRUE if a key was selected; this->key is then a heap copy
 *					of its blob. FALSE otherwise, with this->key left empty.
 *
 * The response is read with a single read() into a fixed buffer; a response
 * that does not arrive completely in that read is rejected as invalid.
 */
static bool read_key(private_agent_private_key_t *this, public_key_t *pubkey)
{
	u_int32_t len, msg_len;
	ssize_t received;
	char buf[2048];
	chunk_t blob, key, type, n;

	len = htonl(1);
	buf[0] = SSH_AGENT_ID_REQUEST;
	if (write(this->socket, &len, sizeof(len)) != sizeof(len) ||
		write(this->socket, &buf, 1) != 1)
	{
		DBG1(DBG_LIB, "writing to ssh-agent failed");
		return FALSE;
	}

	/* check for errors before the result is used as an unsigned length */
	received = read(this->socket, buf, sizeof(buf));
	if (received < 0)
	{
		DBG1(DBG_LIB, "reading from ssh-agent failed: %s", strerror(errno));
		return FALSE;
	}
	blob = chunk_create(buf, received);

	if (blob.len < sizeof(u_int32_t) + sizeof(u_char))
	{
		DBG1(DBG_LIB, "received invalid ssh-agent identity response");
		return FALSE;
	}
	/* consume the length first, then compare it to what remains; doing both
	 * in one expression would depend on the operand evaluation order */
	msg_len = read_uint32(&blob);
	if (msg_len != blob.len ||
		read_byte(&blob) != SSH_AGENT_ID_RESPONSE)
	{
		DBG1(DBG_LIB, "received invalid ssh-agent identity response");
		return FALSE;
	}
	/* skip the number of identities, the loop is bounded by the data */
	read_uint32(&blob);

	while (blob.len)
	{
		key = read_string(&blob);
		if (!key.len)
		{
			break;
		}
		/* skip the comment following each key blob, so the next iteration
		 * starts at the next key */
		read_string(&blob);

		/* get_fingerprint(), as used by private_key_belongs_to(), parses
		 * this->key, so point it at the candidate before checking it */
		this->key = key;
		type = read_string(&key);
		if (type.len != sizeof("ssh-rsa") - 1 ||
			!strneq("ssh-rsa", type.ptr, type.len))
		{	/* not an RSA key, try the next identity */
			continue;
		}
		/* skip the public exponent */
		read_string(&key);
		n = read_string(&key);
		if (n.len <= 512/8)
		{	/* modulus missing or too short, try the next identity */
			continue;
		}
		if (pubkey && !private_key_belongs_to(&this->public.interface, pubkey))
		{
			/* the comparison may have cached a fingerprint of this candidate
			 * under this object; drop it so that the next candidate is not
			 * checked against stale data */
			lib->encoding->clear_cache(lib->encoding, this);
			continue;
		}
		this->key_size = n.len;
		if (n.ptr[0] == 0)
		{	/* do not count the leading zero byte of the modulus */
			this->key_size--;
		}
		/* the blob lives in the stack buffer, keep a copy of our own */
		this->key = chunk_clone(this->key);
		return TRUE;
	}
	this->key = chunk_empty;
	return FALSE;
}

/**
 * Implementation of agent_private_key.sign.
 *
 * Asks the ssh-agent to sign data with the key selected in read_key().
 *
 * @param this		private key object
 * @param scheme	signature scheme, only SIGN_RSA_EMSA_PKCS1_SHA1 is accepted
 * @param data		data to sign
 * @param signature	receives an allocated copy of the signature on success
 * @return			TRUE if a signature was returned by the agent
 *
 * The response is read with a single read() into a fixed buffer; a response
 * that does not arrive completely in that read is rejected as invalid.
 */
static bool sign(private_agent_private_key_t *this, signature_scheme_t scheme,
				 chunk_t data, chunk_t *signature)
{
	u_int32_t len, flags, msg_len;
	ssize_t received;
	char buf[2048];
	chunk_t blob;

	if (scheme != SIGN_RSA_EMSA_PKCS1_SHA1)
	{
		DBG1(DBG_LIB, "signature scheme %N not supported by ssh-agent",
			 signature_scheme_names, scheme);
		return FALSE;
	}

	/* message: type byte, key blob string, data string, flags */
	len = htonl(1 + sizeof(u_int32_t) * 3 + this->key.len + data.len);
	buf[0] = SSH_AGENT_SIGN_REQUEST;
	if (write(this->socket, &len, sizeof(len)) != sizeof(len) ||
		write(this->socket, &buf, 1) != 1)
	{
		DBG1(DBG_LIB, "writing to ssh-agent failed");
		return FALSE;
	}

	len = htonl(this->key.len);
	if (write(this->socket, &len, sizeof(len)) != sizeof(len) ||
		write(this->socket, this->key.ptr, this->key.len) != this->key.len)
	{
		DBG1(DBG_LIB, "writing to ssh-agent failed");
		return FALSE;
	}

	len = htonl(data.len);
	if (write(this->socket, &len, sizeof(len)) != sizeof(len) ||
		write(this->socket, data.ptr, data.len) != data.len)
	{
		DBG1(DBG_LIB, "writing to ssh-agent failed");
		return FALSE;
	}

	flags = htonl(0);
	if (write(this->socket, &flags, sizeof(flags)) != sizeof(flags))
	{
		DBG1(DBG_LIB, "writing to ssh-agent failed");
		return FALSE;
	}

	/* check for errors before the result is used as an unsigned length */
	received = read(this->socket, buf, sizeof(buf));
	if (received < 0)
	{
		DBG1(DBG_LIB, "reading from ssh-agent failed: %s", strerror(errno));
		return FALSE;
	}
	blob = chunk_create(buf, received);

	if (blob.len < sizeof(u_int32_t) + sizeof(u_char))
	{
		DBG1(DBG_LIB, "received invalid ssh-agent signature response");
		return FALSE;
	}
	/* consume the length first, then compare it to what remains; doing both
	 * in one expression would depend on the operand evaluation order */
	msg_len = read_uint32(&blob);
	if (msg_len != blob.len ||
		read_byte(&blob) != SSH_AGENT_SIGN_RESPONSE)
	{
		DBG1(DBG_LIB, "received invalid ssh-agent signature response");
		return FALSE;
	}
	/* parse length */
	blob = read_string(&blob);
	/* skip sig type */
	read_string(&blob);
	/* parse length */
	blob = read_string(&blob);
	if (!blob.len)
	{
		DBG1(DBG_LIB, "received invalid ssh-agent signature response");
		return FALSE;
	}
	/* the signature lives in the stack buffer, hand out a copy */
	*signature = chunk_clone(blob);
	return TRUE;
}

/**
 * Implementation of agent_private_key.get_type.
 *
 * Always reports KEY_RSA.
 */
static key_type_t get_type(private_agent_private_key_t *this)
{
	return KEY_RSA;
}

/**
 * Implementation of agent_private_key.decrypt.
 *
 * Decryption is not offered through the ssh-agent: logs a message and
 * returns FALSE without touching plain.
 */
static bool decrypt(private_agent_private_key_t *this,
					chunk_t crypto, chunk_t *plain)
{
	DBG1(DBG_LIB, "private key decryption not supported by ssh-agent");
	return FALSE;
}

/**
 * Implementation of agent_private_key.get_keysize.
 *
 * Returns the key size stored by read_key(), which the object documents as
 * being in bytes.
 */
static size_t get_keysize(private_agent_private_key_t *this)
{
	return this->key_size;
}

/**
 * Implementation of agent_private_key.get_public_key.
 *
 * Builds an RSA public key object from the exponent and modulus contained
 * in the stored ssh key blob.
 */
static public_key_t* get_public_key(private_agent_private_key_t *this)
{
	chunk_t remaining = this->key;
	chunk_t exponent, modulus;

	/* blob layout: key type, public exponent, modulus */
	read_string(&remaining);
	exponent = read_string(&remaining);
	modulus = read_string(&remaining);

	return lib->creds->create(lib->creds, CRED_PUBLIC_KEY, KEY_RSA,
							  BUILD_RSA_MODULUS, modulus,
							  BUILD_RSA_PUB_EXP, exponent,
							  BUILD_END);
}

/**
 * Implementation of private_key_t.get_encoding
 *
 * No encoding of the private key is provided: always returns FALSE and
 * leaves encoding untouched.
 */
static bool get_encoding(private_agent_private_key_t *this,
						 cred_encoding_type_t type, chunk_t *encoding)
{
	return FALSE;
}

/**
 * Implementation of private_key_t.get_fingerprint
 *
 * Serves the fingerprint from the encoding cache if it is available there,
 * otherwise has it encoded from the exponent and modulus of the stored ssh
 * key blob.
 */
static bool get_fingerprint(private_agent_private_key_t *this,
							cred_encoding_type_t type, chunk_t *fp)
{
	chunk_t remaining, exponent, modulus;

	if (!lib->encoding->get_cache(lib->encoding, type, this, fp))
	{
		/* blob layout: key type, public exponent, modulus */
		remaining = this->key;
		read_string(&remaining);
		exponent = read_string(&remaining);
		modulus = read_string(&remaining);

		return lib->encoding->encode(lib->encoding, type, this, fp,
									 CRED_PART_RSA_MODULUS, modulus,
									 CRED_PART_RSA_PUB_EXP, exponent,
									 CRED_PART_END);
	}
	return TRUE;
}

/**
 * Implementation of agent_private_key.get_ref.
 *
 * Takes an additional reference on the object and returns the same object.
 */
static private_agent_private_key_t* get_ref(private_agent_private_key_t *this)
{
	ref_get(&this->ref);
	return this;
}

/**
 * Implementation of agent_private_key.destroy.
 *
 * Drops one reference. Once ref_put() reports that the object is to be
 * released, the agent connection is closed, the key blob is freed, cached
 * encodings for this object are cleared and the object itself is freed.
 */
static void destroy(private_agent_private_key_t *this)
{
	if (!ref_put(&this->ref))
	{
		return;
	}
	close(this->socket);
	free(this->key.ptr);
	lib->encoding->clear_cache(lib->encoding, this);
	free(this);
}

/**
 * See header.
 *
 * Accepted builder parts:
 *  - BUILD_AGENT_SOCKET (char*): path of the ssh-agent socket, required
 *  - BUILD_PUBLIC_KEY (public_key_t*): only use a key belonging to this
 *    public key, optional
 *  - BUILD_END: terminates the list
 *
 * Returns NULL if an unknown part is passed, no socket path is given, the
 * connection to the agent fails or the agent offers no usable key.
 */
agent_private_key_t *agent_private_key_open(key_type_t type, va_list args)
{
	private_agent_private_key_t *this;
	public_key_t *pubkey = NULL;
	char *path = NULL;

	while (TRUE)
	{
		switch (va_arg(args, builder_part_t))
		{
			case BUILD_AGENT_SOCKET:
				path = va_arg(args, char*);
				continue;
			case BUILD_PUBLIC_KEY:
				pubkey = va_arg(args, public_key_t*);
				continue;
			case BUILD_END:
				break;
			default:
				return NULL;
		}
		break;
	}
	if (!path)
	{
		return NULL;
	}

	this = malloc_thing(private_agent_private_key_t);

	this->public.interface.get_type = (key_type_t (*)(private_key_t *this))get_type;
	this->public.interface.sign = (bool (*)(private_key_t *this, signature_scheme_t scheme, chunk_t data, chunk_t *signature))sign;
	this->public.interface.decrypt = (bool (*)(private_key_t *this, chunk_t crypto, chunk_t *plain))decrypt;
	this->public.interface.get_keysize = (size_t (*) (private_key_t *this))get_keysize;
	this->public.interface.get_public_key = (public_key_t* (*)(private_key_t *this))get_public_key;
	this->public.interface.belongs_to = private_key_belongs_to;
	this->public.interface.equals = private_key_equals;
	this->public.interface.get_fingerprint = (bool(*)(private_key_t*, cred_encoding_type_t type, chunk_t *fp))get_fingerprint;
	this->public.interface.has_fingerprint = (bool(*)(private_key_t*, chunk_t fp))private_key_has_fingerprint;
	this->public.interface.get_encoding = (bool(*)(private_key_t*, cred_encoding_type_t type, chunk_t *encoding))get_encoding;
	this->public.interface.get_ref = (private_key_t* (*)(private_key_t *this))get_ref;
	this->public.interface.destroy = (void (*)(private_key_t *this))destroy;

	this->socket = open_connection(path);
	if (this->socket < 0)
	{
		/* no socket to close and no key yet, so a plain free() is enough */
		free(this);
		return NULL;
	}
	/* initialize everything destroy() and the getters look at before
	 * read_key() may fail */
	this->key = chunk_empty;
	this->key_size = 0;
	this->ref = 1;

	if (!read_key(this, pubkey))
	{
		destroy(this);
		return NULL;
	}
	return &this->public;
}

