/* $Id: sfsserv.C,v 1.5 2001/03/22 02:56:35 dm Exp $ */

/*
 *
 * Copyright (C) 2000 David Mazieres (dm@uun.org)
 *
 * This program is free software; you can redistribute it and/or
 * modify it under the terms of the GNU General Public License as
 * published by the Free Software Foundation; either version 2, or (at
 * your option) any later version.
 *
 * This program is distributed in the hope that it will be useful, but
 * WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
 * General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program; if not, write to the Free Software
 * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307
 * USA
 *
 */

#include "sfsserv.h"
#include <grp.h>

/*
 * Return an RPC client for the authserv daemon (sfsauth_program_1).
 * The client and its stream transport are cached in function-local
 * statics and reused until the transport reports EOF, at which point a
 * fresh descriptor is obtained via suidgetfd ("authserv").
 * Returns NULL (and drops any cached state) if no descriptor is available.
 */
ptr<aclnt>
getauthclnt ()
{
  static ptr<axprt> authxprt;
  static ptr<aclnt> authclnt;

  // Fast path: cached client whose transport is still open.
  if (authclnt && !authxprt->ateof ())
    return authclnt;

  int authfd = suidgetfd ("authserv");
  if (authfd >= 0) {
    authxprt = axprt_stream::alloc (authfd);
    authclnt = aclnt::alloc (authxprt, sfsauth_program_1);
  }
  else {
    // Forget the stale connection so the next call retries from scratch.
    authxprt = NULL;
    authclnt = NULL;
  }
  return authclnt;
}

/*
 * Construct the per-connection SFS server state.
 *   xxc - encrypting transport; kept in xc (handed to sfs_server_crypt
 *         by sfs_encrypt).
 *   xx  - optional transport to serve SFS RPCs on; when null, xxc itself
 *         is used.
 * Incoming SFS RPCs on x are routed to sfsserv::dispatch.  `destroyed`
 * is a shared flag that pending callbacks (e.g. sfs_login_cb) check
 * after the destructor sets it.  Slot 0 of authtab/credtab is filled
 * here so that real authnos start at 1 (authalloc returns 0 to signal
 * failure and authfree ignores 0).
 */
sfsserv::sfsserv (ref<axprt_crypt> xxc, ptr<axprt> xx)
  : xc (xxc), x (xx ? xx : implicit_cast<ptr<axprt> > (xxc)),
    destroyed (New refcounted<bool> (false)),
    // NOTE(review): members are initialized in their declaration order in
    // sfsserv.h, not in the order written here.  sfssrv reads x, so x must
    // be declared before sfssrv -- confirm in the header.
    sfssrv (asrv::alloc (x, sfs_program_1, wrap (this, &sfsserv::dispatch))),
    // NOTE(review): 128 is passed to seqstate's constructor; presumably a
    // replay-window size for login sequence numbers -- confirm.
    seqstate (128), authid_valid (false)
{
  // Reserve authno 0 (never handed out as a valid authentication number).
  authtab.push_back ();
  credtab.push_back ();
}

/*
 * Mark this server as gone.  The flag object is reference counted and
 * shared with outstanding callbacks, so it outlives *this and lets those
 * callbacks avoid touching the freed server.
 */
sfsserv::~sfsserv ()
{
  bool &dead = *destroyed;
  dead = true;
}

/*
 * Allocate an authentication number (index into authtab/credtab).
 * Recycles a previously freed number when possible; otherwise grows both
 * tables by one slot.  Returns 0 (the reserved slot) when the tables
 * already hold 0x10000 entries.
 */
u_int32_t
sfsserv::authalloc ()
{
  if (authfreelist.size () > 0)
    return authfreelist.pop_back ();

  if (authtab.size () < 0x10000) {
    authtab.push_back ();
    credtab.push_back ();
    return authtab.size () - 1;
  }

  // Table full: 0 tells the caller no authno is available.
  return 0;
}
/*
 * Release authentication number n.  Ignores the reserved slot 0,
 * out-of-range numbers, and slots that are not currently in use, so a
 * client cannot corrupt the free list by logging out twice.
 */
void
sfsserv::authfree (size_t n)
{
  if (n == 0 || n >= authtab.size () || !authtab[n])
    return;

  authtab[n] = NULL;
  credtab[n].set_type (SFS_NOCRED);
  authfreelist.push_back (n);
}

/*
 * Route one incoming SFS RPC to its handler.  A NULL svccb is ignored
 * (presumably the transport signalling EOF -- confirm against asrv).
 * Unknown procedure numbers are handed to sfs_badproc.
 */
void
sfsserv::dispatch (svccb *sbp)
{
  if (sbp == NULL)
    return;

  switch (sbp->proc ()) {
  case SFSPROC_NULL:
    // Ping: empty reply.
    sbp->reply (NULL);
    return;
  case SFSPROC_CONNECT:
    sfs_connect (sbp);
    return;
  case SFSPROC_ENCRYPT:
    sfs_encrypt (sbp);
    return;
  case SFSPROC_GETFSINFO:
    sfs_getfsinfo (sbp);
    return;
  case SFSPROC_LOGIN:
    sfs_login (sbp);
    return;
  case SFSPROC_LOGOUT:
    sfs_logout (sbp);
    return;
  case SFSPROC_IDNAMES:
    sfs_idnames (sbp);
    return;
  case SFSPROC_IDNUMS:
    sfs_idnums (sbp);
    return;
  case SFSPROC_GETCRED:
    sfs_getcred (sbp);
    return;
  default:
    sfs_badproc (sbp);
    return;
  }
}

/*
 * SFSPROC_CONNECT: first step of session setup.
 * Saves the client's connect argument in cd, fills in a charge
 * (bitcost = sfs_hashcost, target = fresh random bytes), and asks
 * doconnect for the server key and servinfo.  The whole connect result
 * stays in cd for the subsequent SFSPROC_ENCRYPT.
 * Rejected with PROC_UNAVAIL if a CONNECT is already pending (cd set) or
 * the session is already established (authid_valid set by sfs_encrypt).
 */
void
sfsserv::sfs_connect (svccb *sbp)
{
  if (cd || authid_valid) {
    sbp->reject (PROC_UNAVAIL);
    return;
  }
  cd.alloc ();
  cd->ci = *sbp->template getarg <sfs_connectarg> ();
  cd->cr.set_status (SFS_OK);
  cd->cr.reply->charge.bitcost = sfs_hashcost;
  // Randomize the charge target.  The byte count must come from the very
  // buffer being filled; the previous code used an unrelated, unqualified
  // `charge` for the size.
  rnd.getbytes (cd->cr.reply->charge.target.base (),
		cd->cr.reply->charge.target.size ());
  sk = doconnect (&cd->ci, &cd->cr.reply->servinfo);
  // If doconnect produced no key and did not already set an error status,
  // report that the requested host is not served here.
  if (!sk && !cd->cr.status)
    cd->cr.set_status (SFS_NOSUCHHOST);
  sbp->reply (&cd->cr);
}

/*
 * SFSPROC_ENCRYPT: second step of session setup, following a successful
 * SFSPROC_CONNECT.
 * Calls sfs_server_crypt with the saved connect state, the server key sk,
 * the charge and the crypt transport xc; it writes the session id into
 * sessid.  The authid used to validate later LOGIN replies is then
 * derived from the service, hostname, host id and sessid, and
 * authid_valid is set.  Finally cd is cleared; together with authid_valid
 * this makes further CONNECT/ENCRYPT calls return PROC_UNAVAIL.
 * NOTE(review): this function never replies to sbp itself; presumably
 * sfs_server_crypt sends the reply -- confirm.
 */
void
sfsserv::sfs_encrypt (svccb *sbp)
{
  // Only valid after a CONNECT whose reply status was SFS_OK.
  if (!cd || cd->cr.status) {
    sbp->reject (PROC_UNAVAIL);
    return;
  }
  // Must run before sfs_get_authid below, which reads sessid.
  sfs_server_crypt (sbp, sk, cd->ci, cd->cr.reply->servinfo,
		    &sessid, cd->cr.reply->charge, xc);
  sfs_hash hostid;
  // sfs_mkhostid is called outside assert() so it still runs under NDEBUG.
  // Under NDEBUG, however, a failure here goes undetected.
  bool hostid_ok = sfs_mkhostid (&hostid, cd->cr.reply->servinfo.host);
  assert (hostid_ok);
  sfs_get_authid (&authid, cd->ci.service,
		  cd->cr.reply->servinfo.host.hostname,
		  &hostid, &sessid);
  authid_valid = true;
  // Connect state is no longer needed once the session is keyed.
  cd.clear ();
}

/*
 * Completion callback for the SFSAUTHPROC_LOGIN call made by
 * sfsserv::sfs_login.  Translates authserv's sfsauth_loginres into the
 * client's sfs_loginres and replies on sbp.
 *   destroyed - shared flag set by ~sfsserv; when true, srv must not be
 *               dereferenced.
 *   srv       - the server that issued the call.
 *   sbp       - the client's pending LOGIN request.
 *   _resp     - heap-allocated result buffer; ownership passes to this
 *               function and it is freed on return.
 *   stat      - RPC status of the authserv call.
 * On SFSLOGIN_OK the reply is accepted only if the authid matches this
 * session, seqstate accepts the sequence number and the credential is a
 * Unix credential.  In that case a fresh authno is bound to an AUTH handle
 * built from the uid/gid/groups.
 */
static void
sfs_login_cb (ref<bool> destroyed, sfsserv *srv, svccb *sbp,
	      sfsauth_loginres *_resp, clnt_stat stat)
{
  // Take ownership so the result is deleted on every return path.
  auto_ptr<sfsauth_loginres> resp (_resp);
  if (stat || *destroyed) {
    // RPC failure or server already gone: refuse without touching srv.
    // NOTE(review): when *destroyed is set, sbp is still replied to here;
    // confirm that is safe once the owning sfsserv has been destroyed.
    if (stat)
      warn << "authserv: " << stat << "\n";
    sbp->replyref (sfs_loginres (SFSLOGIN_ALLBAD));
    return;
  }

  // Mirror authserv's status; the union arm is then filled in per case.
  sfs_loginres res (resp->status);
  switch (resp->status) {
  case SFSLOGIN_OK:
    {
      // Reject replies for a different session, sequence numbers that
      // seqstate refuses (presumably replays -- confirm), or non-Unix creds.
      if (resp->resok->authid != srv->authid
	  || !srv->seqstate.check (resp->resok->seqno)
	  || resp->resok->cred.type != SFS_UNIXCRED) {
	res.set_status (SFSLOGIN_BAD);
	break;
      }

      // authalloc returns 0 when the auth table is full.
      u_int32_t authno = srv->authalloc ();
      if (!authno) {
	warn << "ran out of authnos\n";
	res.set_status (SFSLOGIN_BAD);
	break;
      }

      *res.authno = authno;
      sfs_unixcred &uc = *resp->resok->cred.unixcred;
      srv->authtab[authno] = authunixint_create ("localhost", uc.uid, uc.gid,
						 uc.groups.size (),
						 uc.groups.base ());
      srv->credtab[authno] = resp->resok->cred;
      break;
    }
  case SFSLOGIN_MORE:
    // Multi-round login: pass authserv's continuation data to the client.
    *res.resmore = *resp->resmore;
    break;
  default:
    // Any other status is forwarded as-is with no payload.
    break;
  }
  sbp->reply (&res);
}
/*
 * SFSPROC_LOGIN: forward the client's login argument to authserv.
 * Answers SFSLOGIN_ALLBAD immediately if the session is not yet keyed
 * (no authid) or authserv is unreachable.  Otherwise the reply is
 * produced asynchronously by sfs_login_cb, which also takes ownership
 * of the heap-allocated result buffer.
 */
void
sfsserv::sfs_login (svccb *sbp)
{
  ptr<aclnt> authc;
  if (authid_valid)
    authc = getauthclnt ();

  if (!authc) {
    sbp->replyref (sfs_loginres (SFSLOGIN_ALLBAD));
    return;
  }

  sfsauth_loginres *authres = New sfsauth_loginres;
  authc->call (SFSAUTHPROC_LOGIN, sbp->template getarg<sfs_loginarg> (),
	       authres, wrap (sfs_login_cb, destroyed, this, sbp, authres));
}

/*
 * SFSPROC_LOGOUT: release the authno named in the argument.  Invalid or
 * unused numbers are silently ignored by authfree; the reply is always
 * empty.
 */
void
sfsserv::sfs_logout (svccb *sbp)
{
  u_int32_t authno = *sbp->template getarg<u_int32_t> ();
  authfree (authno);
  sbp->reply (NULL);
}

/*
 * SFSPROC_IDNAMES: map a numeric uid/gid pair to user/group names.
 * An id equal to -1 is not looked up; ids without a local entry are left
 * absent in the reply.  Requires an authno accepted by getauth.
 *
 * XXX - MAJOR DEADLOCK PROBLEMS HERE
 * XXX - should never call getpw* or getgr* (these lookups run
 *       synchronously inside the server; presumably they can block on an
 *       external name service).
 */
void
sfsserv::sfs_idnames (svccb *sbp)
{
  if (!getauth (sbp->getaui ())) {
    sbp->reject (AUTH_REJECTEDCRED);
    return;
  }

  ::sfs_idnums *ids = sbp->template getarg< ::sfs_idnums> ();
  ::sfs_idnames names;

  // getpwuid/getgrgid return static storage, so each name is copied out
  // before the next lookup.
  struct passwd *pw = (ids->uid == -1) ? NULL : getpwuid (ids->uid);
  if (pw) {
    names.uidname.set_present (true);
    *names.uidname.name = pw->pw_name;
  }

  struct group *gr = (ids->gid == -1) ? NULL : getgrgid (ids->gid);
  if (gr) {
    names.gidname.set_present (true);
    *names.gidname.name = gr->gr_name;
  }

  sbp->reply (&names);
}

/*
 * SFSPROC_IDNUMS: map user/group names to numeric uid/gid.
 * Each field of the reply starts as -1 and is filled in only when the
 * corresponding name is present and known locally.  Requires an authno
 * accepted by getauth.
 *
 * XXX - MAJOR DEADLOCK PROBLEMS HERE
 * XXX - should never call getpw* or getgr* (these lookups run
 *       synchronously inside the server; presumably they can block on an
 *       external name service).
 */
void
sfsserv::sfs_idnums (svccb *sbp)
{
  if (!getauth (sbp->getaui ())) {
    sbp->reject (AUTH_REJECTEDCRED);
    return;
  }

  ::sfs_idnames *names = sbp->template getarg< ::sfs_idnames> ();
  ::sfs_idnums ids = { -1, -1 };

  if (names->uidname.present) {
    struct passwd *pw = getpwnam (names->uidname.name->cstr ());
    if (pw)
      ids.uid = pw->pw_uid;
  }
  if (names->gidname.present) {
    struct group *gr = getgrnam (names->gidname.name->cstr ());
    if (gr)
      ids.gid = gr->gr_gid;
  }

  sbp->reply (&ids);
}

/*
 * SFSPROC_GETCRED: return the credential bound to the caller's authno.
 * Numbers beyond the table yield an SFS_NOCRED credential; freed slots
 * already hold SFS_NOCRED (set by authfree).
 */
void
sfsserv::sfs_getcred (svccb *sbp)
{
  u_int32_t aui = sbp->getaui ();
  if (aui >= credtab.size ()) {
    sbp->replyref (sfsauth_cred (SFS_NOCRED));
    return;
  }
  sbp->replyref (credtab[aui]);
}


/*
 * Wrap a newly accepted connection in an encrypting transport and hand
 * it to cb.  A negative fd is reported to cb as NULL.
 * Returns the transport passed to cb (NULL on failure).
 */
static ptr<axprt_stream>
sfs_accept (sfsserv_cb cb, int fd)
{
  if (fd >= 0) {
    // Disable Nagle: SFS RPCs are small, latency-sensitive messages.
    tcp_nodelay (fd);
    ref<axprt_crypt> xprt = axprt_crypt::alloc (fd);
    (*cb) (xprt);
    return xprt;
  }
  (*cb) (NULL);
  return NULL;
}

/*
 * Read-readiness handler for the standalone listening socket: accept one
 * connection and pass it to sfs_accept.  EAGAIN is treated as a benign
 * spurious wakeup; any other accept error is logged and otherwise ignored.
 */
static void
sfs_accept_standalone (sfsserv_cb cb, int sfssfd)
{
  sockaddr_in peer;
  bzero (&peer, sizeof (peer));
  socklen_t peerlen = sizeof (peer);

  int connfd = accept (sfssfd, reinterpret_cast<sockaddr *> (&peer),
		       &peerlen);
  if (connfd < 0) {
    if (errno != EAGAIN)
      warn ("accept: %m\n");
    return;
  }
  sfs_accept (cb, connfd);
}

/*
 * Start accepting SFS client connections, delivering each one to cb.
 * Normally connections are handed over by a running sfssd (cloneserv).
 * If no sfssd is detected, fall back to standalone mode: bind and listen
 * on TCP port sfs_port and accept connections directly.  Failing to bind
 * or to listen is fatal, because the server would otherwise run without
 * ever receiving a connection.
 */
void
sfssd_slave (sfsserv_cb cb)
{
  if (cloneserv (0, wrap (sfs_accept, cb)))
    return;

  warn ("No sfssd detected, running in standalone mode.\n");
  int sfssfd = inetsocket (SOCK_STREAM, sfs_port);
  if (sfssfd < 0)
    fatal ("binding TCP port %d: %m\n", sfs_port);
  close_on_exec (sfssfd);
  // listen() can fail (e.g. EADDRINUSE on some systems); previously the
  // result was ignored and the server would silently never accept.
  if (listen (sfssfd, 5) < 0)
    fatal ("listening on TCP port %d: %m\n", sfs_port);
  // NOTE(review): sfs_accept_standalone treats EAGAIN as benign, which only
  // occurs on a non-blocking socket; confirm whether inetsocket sets
  // O_NONBLOCK, otherwise accept() can block if a client aborts between
  // readiness notification and accept.
  fdcb (sfssfd, selread, wrap (sfs_accept_standalone, cb, sfssfd));
}
