/*
** Copyright (C) 2001 Jed Pickel <jed@pickel.net>
** Portions Copyright (C) 2000,2001 Carnegie Mellon University
** Portions Copyright (C) 2001 Andrew R. Baker <andrewb@farm9.com>
**
** This program is free software; you can redistribute it and/or modify
** it under the terms of the GNU General Public License as published by
** the Free Software Foundation; either version 2 of the License, or
** (at your option) any later version.
**
** This program is distributed in the hope that it will be useful,
** but WITHOUT ANY WARRANTY; without even the implied warranty of
** MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
** GNU General Public License for more details.
**
** You should have received a copy of the GNU General Public License
** along with this program; if not, write to the Free Software
** Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA 02111-1307, USA.
*/

/* $Id: spo_database.c,v 1.37.2.2 2002/01/16 16:57:01 fygrave Exp $ */

/* Snort Database Output Plug-in by Jed Pickel <jed@pickel.net>
 * 
 * For configuration information, see the README.database file
 * included with this distribution or the snortdb web site.
 *
 * Web Site: http://www.incident.org/snortdb
 */

#include "spo_database.h"
extern PV pv;


/* list for lookup of shared data information */
/* One entry of a singly linked list of SharedDatabaseData records. */
typedef struct _SharedDatabaseDataNode
{
    SharedDatabaseData *data;              /* shared record referenced by this entry */
    struct _SharedDatabaseDataNode *next;  /* following entry, or NULL at the end of the list */
} SharedDatabaseDataNode;

/* Head of the shared-data list; DatabaseInit() searches it and appends to it. */
static SharedDatabaseDataNode *sharedDataList = NULL;
/* Incremented once at the end of every DatabaseInit() call. */
static int instances = 0;

/* Locally defined functions */
/* NOTE(review): the definition is not visible in this part of the file;
   presumably it releases sharedDataList -- confirm. */
void FreeSharedDataList();


/* If you want extra debugging information for solving database 
   configuration problems, uncomment the following line. */
/* #define DEBUG */


/* The following is for supporting Microsoft SQL Server */
#ifdef ENABLE_MSSQL

/* If you want extra debugging information (specific to
   Microsoft SQL Server), uncomment the following line. */
/* NOTE(review): the define below is currently active, so the statement
   tracking buffer and macros are compiled in whenever ENABLE_MSSQL is set.
   Comment it out to disable them. */
#define ENABLE_MSSQL_DEBUG

#if defined(DEBUG) || defined(ENABLE_MSSQL_DEBUG)
    /* this is for debugging purposes only */
    /* Holds a copy of a statement text; filled and cleared only through the
       two macros below. */
    static char g_CurrentStatement[2048];
    /* SAVESTATEMENT copies at most sizeof(buffer) - 1 bytes, so the final byte
       of the (static, zero-initialised) buffer is never written by it and the
       buffer stays NUL-terminated.
       CLEARSTATEMENT zeroes the whole buffer.
       Both expansions already end in ';', so "SAVESTATEMENT(x);" expands to
       two statements -- do not use them as the unbraced body of an if/else. */
    #define SAVESTATEMENT(str)   strncpy(g_CurrentStatement, str, sizeof(g_CurrentStatement) - 1);
    #define CLEARSTATEMENT()     bzero((char *) g_CurrentStatement, sizeof(g_CurrentStatement));
#else
    /* Debugging disabled: the macros expand to an expression statement with
       no effect (and do not evaluate their argument). */
    #define SAVESTATEMENT(str)   NULL;
    #define CLEARSTATEMENT()     NULL;
#endif /* DEBUG || ENABLE_MSSQL_DEBUG*/

    /* Predeclaration of SQL Server callback functions. See actual declaration elsewhere for details. */
    static int mssql_err_handler(PDBPROCESS dbproc, int severity, int dberr, int oserr, LPCSTR dberrstr, LPCSTR oserrstr);
    static int mssql_msg_handler(PDBPROCESS dbproc, DBINT msgno, int msgstate, int severity, LPCSTR msgtext, LPCSTR srvname, LPCSTR procname, DBUSMALLINT line);

#endif /* ENABLE_MSSQL */


/*
 * Function: SetupDatabase()
 *
 * Purpose: Makes the "database" output plugin known to Snort by registering
 *          its keyword together with DatabaseInit() as the initialization
 *          routine.  Called from InitOutputPlugins() in plugbase.c.
 *
 * Arguments: None.
 *
 * Returns: void function
 *
 */
void SetupDatabase()
{
    /* associate the "database" keyword with its init routine in the
       output plugin list */
    RegisterOutputPlugin("database", NT_OUTPUT_ALERT, DatabaseInit);

#if defined(DEBUG)
    printf("database(debug): database plugin is registered...\n");
#endif /* DEBUG */
}

/*
 * Helper: case-insensitive equality test for two optional strings.
 *
 * Returns non-zero when both strings are equal (ignoring case) or when both
 * are NULL; a NULL string never equals a non-NULL one.  strcasecmp() itself
 * must not be handed a NULL pointer.
 */
static int DatabaseStringsMatch(const char *lhs, const char *rhs)
{
    if(lhs == NULL || rhs == NULL)
    {
        /* equal only when both are absent */
        return lhs == rhs;
    }

    return strcasecmp(lhs, rhs) == 0;
}

/*
 * Function: DatabaseInit(u_char *)
 *
 * Purpose: Initializes one instance of the database output plugin:
 *            - parses the argument string (ParseDatabaseArgs),
 *            - connects to the database,
 *            - looks up (or inserts and then looks up) the sensor id,
 *            - shares the sid/cid record with an already configured instance
 *              that refers to the same sid, dbtype, dbname and host, or
 *              registers a new shared record and seeds its cid,
 *            - verifies the DB schema version,
 *            - adds Database() to the log or alert output list and registers
 *              the clean-exit and restart handlers.
 *
 * Arguments: args => ptr to argument string (it is tokenized in place and
 *                    the resulting configuration keeps pointers into it)
 *
 * Returns: void function.  Unrecoverable problems are reported through
 *          FatalError().
 *
 */
void DatabaseInit(u_char *args)
{
    DatabaseData *data;
    char *select0;
    char *select1;
    char *insert0;
    SharedDatabaseDataNode *current;
    SharedDatabaseDataNode *tail;
    SharedDatabaseDataNode *newNode;

    /* parse the argument list from the rules file */
    data = ParseDatabaseArgs((char *)args);

    /* find a unique name for sensor if one was not supplied as an option */
    if(!data->sensor_name)
    {
        data->sensor_name = GetUniqueName((char *)PRINT_INTERFACE(pv.interfaces[0]));
        if( !pv.quiet_flag ) printf("database:   sensor name = %s\n", data->sensor_name);
    }

    data->tz = GetLocalTimezone();

    /* allocate memory for configuration queries */
    select0 = (char *)calloc(MAX_QUERY_LENGTH, sizeof(char));
    select1 = (char *)calloc(MAX_QUERY_LENGTH, sizeof(char));
    insert0 = (char *)calloc(MAX_QUERY_LENGTH, sizeof(char));

    if(select0 == NULL || select1 == NULL || insert0 == NULL)
    {
        /* snprintf() below must never be handed a NULL buffer */
        FatalError("database: unable to allocate memory for configuration queries\n");
    }

    if(pv.pcap_cmd == NULL)
    {
        snprintf(insert0, MAX_QUERY_LENGTH, 
                 "INSERT INTO sensor (hostname, interface, detail, encoding) "
                 "VALUES ('%s','%s','%u','%u')", 
                 data->sensor_name, PRINT_INTERFACE(pv.interfaces[0]), data->detail, data->encoding);
        snprintf(select0, MAX_QUERY_LENGTH, 
                 "SELECT sid FROM sensor WHERE hostname = '%s' "
                 "AND interface = '%s' AND detail = '%u' AND "
                 "encoding = '%u' AND filter IS NULL",
                 data->sensor_name, PRINT_INTERFACE(pv.interfaces[0]), data->detail, data->encoding);
    }
    else
    {
        snprintf(select0, MAX_QUERY_LENGTH, 
                 "SELECT sid FROM sensor WHERE hostname = '%s' "
                 "AND interface = '%s' AND filter ='%s' AND "
                 "detail = '%u' AND encoding = '%u'",
                 data->sensor_name, PRINT_INTERFACE(pv.interfaces[0]), pv.pcap_cmd,
                 data->detail, data->encoding);
        snprintf(insert0, MAX_QUERY_LENGTH, 
                 "INSERT INTO sensor (hostname, interface, filter,"
                 "detail, encoding) "
                 "VALUES ('%s','%s','%s','%u','%u')", 
                 data->sensor_name, PRINT_INTERFACE(pv.interfaces[0]), pv.pcap_cmd,
                 data->detail, data->encoding);
    }

    Connect(data);

    /* look the sensor up; if it is unknown, insert it and look it up again */
    data->shared->sid = Select(select0,data);
    if(data->shared->sid == 0)
    {
        Insert(insert0,data);
        data->shared->sid = Select(select0,data);
        if(data->shared->sid == 0)
        {
            ErrorMessage("database: Problem obtaining SENSOR ID (sid) from %s->%s->sensor\n", data->shared->dbtype,data->shared->dbname);
            FatalError("\n"
                       " When this plugin starts, a SELECT query is run to find the sensor id for the\n"
                       " currently running sensor. If the sensor id is not found, the plugin will run\n"
                       " an INSERT query to insert the proper data and generate a new sensor id. Then a\n"
                       " SELECT query is run to get the newly allocated sensor id. If that fails then\n"
                       " this error message is generated.\n"
                       "\n"
                       " Some possible causes for this error are:\n"
                       " * the user does not have proper INSERT or SELECT privileges\n"
                       " * the sensor table does not exist\n"
                       "\n"
                       " If you are _absolutly_ certain that you have the proper privileges set and\n"
                       " that your database structure is built properly please let me know if you\n"
                       " continue to get this error. You can contact me at (jed@pickel.net).\n"
                       "\n");
        }
    }

    if( !pv.quiet_flag ) printf("database:     sensor id = %u\n", data->shared->sid);

    /* The cid may be shared across multiple instances of the database
       plugin.  First check the shared data list for a record with the same
       four key fields; if one exists, the SharedDatabaseData struct of this
       instance is replaced by the one from the list.
       -Andrew
     */
    /* XXX: Creating a set of list handling functions would make this cleaner */
    /* XXX: This code would be cleaner if dbtype was an int */
    current = sharedDataList;
    while(current != NULL)
    {
        /* host is only set when a "host" option was supplied, so it can be
           NULL on either side; the helper keeps strcasecmp() away from NULL */
        if((current->data->sid == data->shared->sid) &&
           DatabaseStringsMatch(current->data->dbtype, data->shared->dbtype) &&
           /* XXX: should this be a case insensitive compare? */
           DatabaseStringsMatch(current->data->dbname, data->shared->dbname) &&
           DatabaseStringsMatch(current->data->host, data->shared->host))
        {
            break;
        }
        current = current->next;
    }

    if(current == NULL)
    {
        /* No matching record: append this one to the shared data list */
        newNode = (SharedDatabaseDataNode *)calloc(1,sizeof(SharedDatabaseDataNode));
        if(newNode == NULL)
        {
            FatalError("database: unable to allocate memory for the shared data list\n");
        }
        newNode->data = data->shared;
        newNode->next = NULL;
        if(sharedDataList == NULL)
        {
            sharedDataList = newNode;
        }
        else
        {
            tail = sharedDataList;
            while(tail->next != NULL)
                tail = tail->next;
            tail->next = newNode;
        }

        /* Set the cid value: one past the highest cid already stored for this sensor */
        snprintf(select1, MAX_QUERY_LENGTH,
                 "SELECT max(cid) FROM event WHERE sid = '%u'", data->shared->sid);

        data->shared->cid = Select(select1,data);
        ++(data->shared->cid);
    }
    else
    {
        /* Reuse the record that is already on the list.  Only the struct
           itself is freed: its string members were assigned from tokens of
           the argument string in ParseDatabaseArgs(), not allocated. */
        free(data->shared);
        data->shared = current->data;
    }

    /* free memory */
    free(select0);
    free(select1);
    free(insert0);

    /* Get the versioning information for the DB schema */
    data->DBschema_version = CheckDBVersion(data);
    if( !pv.quiet_flag ) printf("database: schema version = %d\n", data->DBschema_version);
    if ( data->DBschema_version < LATEST_DB_SCHEMA_VERSION )
    {
       FatalError("database: The underlying database seems to be running an older version of the DB schema.\n"
                  "          Please re-run the appropriate DB creation script (e.g. create_mysql,\n"
                  "          create_postgresql, create_oracle) located in the contrib\\ directory.\n");
    }

    /* Add the processor function into the function list */
    if(!strncasecmp(data->facility,"log",3))
    {
        pv.log_plugin_active = 1;
        if( !pv.quiet_flag ) printf("database: using the \"log\" facility\n");
        AddFuncToOutputList(Database, NT_OUTPUT_LOG, data);
    }
    else
    {
        pv.alert_plugin_active = 1;
        if( !pv.quiet_flag ) printf("database: using the \"alert\" facility\n");
        AddFuncToOutputList(Database, NT_OUTPUT_ALERT, data);
    }

    AddFuncToCleanExitList(SpoDatabaseCleanExitFunction, data);
    AddFuncToRestartList(SpoDatabaseRestartFunction, data); 
    ++instances;
}

/*
 * Function: ParseDatabaseArgs(char *)
 *
 * Purpose: Process the output plugin arguments from the rules file and
 *          initialize the plugin's data struct.
 *
 *          Expected layout of the argument string:
 *            <facility>, <dbtype>, key=value key=value ...
 *          where facility starts with "log" or "alert" and dbtype names one
 *          of the database types compiled into this binary.  Recognized keys
 *          are host, port, user, password, dbname, sensor_name, encoding
 *          (hex, base64, ascii) and detail (full, fast).  Keys and values are
 *          matched case-insensitively by prefix.
 *
 * Arguments: args => argument list.  It is tokenized in place with strtok()
 *                    and the returned struct keeps pointers into it, so the
 *                    buffer must stay valid for as long as the returned
 *                    configuration is in use.
 *
 * Returns: pointer to a newly allocated DatabaseData (with a newly allocated
 *          SharedDatabaseData in ->shared).  Invalid or missing mandatory
 *          arguments are reported through FatalError().
 *
 */
DatabaseData *ParseDatabaseArgs(char *args)
{
    DatabaseData *data;
    char *dbarg;
    char *a1;
    const char *a1_text;
    char *type;
    char *facility;

    if(args == NULL)
    {
        ErrorMessage("database: you must supply arguments for database plugin\n");
        DatabasePrintUsage();
        FatalError("");
    }

    data = (DatabaseData *)calloc(1, sizeof(DatabaseData));
    if(data == NULL)
    {
        FatalError("database: unable to allocate memory for the plugin configuration\n");
    }

    data->shared = (SharedDatabaseData *)calloc(1, sizeof(SharedDatabaseData));
    if(data->shared == NULL)
    {
        FatalError("database: unable to allocate memory for the plugin configuration\n");
    }

    /* defaults; every other member starts out zeroed by calloc() */
    data->shared->dbtype = NULL;
    data->sensor_name = NULL;
    data->facility = NULL;
    data->encoding = ENCODING_HEX;
    data->detail = DETAIL_FULL;

    /* first argument: the logging facility */
    facility = strtok(args, ", ");
    if(facility != NULL)
    {
        if((!strncasecmp(facility,"log",3)) || (!strncasecmp(facility,"alert",5)))
        {
            data->facility = facility;
        }
        else
        {
            ErrorMessage("database: The first argument needs to be the logging facility\n");
            DatabasePrintUsage();
            FatalError("");
        }
    }
    else
    {
        ErrorMessage("database: Invalid format for first argment\n"); 
        DatabasePrintUsage();
        FatalError("");
    }

    /* second argument: the database type */
    type = strtok(NULL, ", ");

    if(type == NULL)
    {
        ErrorMessage("database: you must enter the database type in configuration file as the second argument\n");
        DatabasePrintUsage();
        FatalError("");
    }

    /* print out and test the capability of this plugin */
    if( !pv.quiet_flag ) printf("database: compiled support for ( ");


#ifdef ENABLE_MYSQL
    if( !pv.quiet_flag ) printf("%s ",MYSQL);
    if(!strncasecmp(type,MYSQL,5))
    {
        data->shared->dbtype = type; 
    }
#endif
#ifdef ENABLE_POSTGRESQL
    if( !pv.quiet_flag ) printf("%s ",POSTGRESQL);
    if(!strncasecmp(type,POSTGRESQL,10))
    {
        data->shared->dbtype = type; 
    }
#endif
#ifdef ENABLE_ODBC
    if( !pv.quiet_flag ) printf("%s ",ODBC);
    if(!strncasecmp(type,ODBC,8))
    {
        data->shared->dbtype = type; 
    }
#endif
#ifdef ENABLE_ORACLE
    if( !pv.quiet_flag ) printf("%s ",ORACLE);
    if(!strncasecmp(type,ORACLE,5))
    {
      data->shared->dbtype = type; 
    }
#endif
#ifdef ENABLE_MSSQL
    if( !pv.quiet_flag ) printf("%s ",MSSQL);
    if(!strncasecmp(type,MSSQL,5))
    {
      data->shared->dbtype = type; 
    }
#endif

    if( !pv.quiet_flag ) printf(")\n");

    if( !pv.quiet_flag ) printf("database: configured to use %s\n", type);

    if(data->shared->dbtype == NULL)
    {
        ErrorMessage("database: %s support is not compiled in this copy\n\n", type);
        FatalError(" Check your configuration file to be sure you did not mis-spell \"%s\".\n If you did not, you will need to reconfigure and recompile ensuring that\n you have set the correct options to the configure script. Type \n \"./configure --help\" to see options for the configure script.\n\n", type);
    }

    /* remaining arguments: key=value pairs */
    dbarg = strtok(NULL, " =");
    while(dbarg != NULL)
    {
        /* a1 is NULL when the last key has no value after it */
        a1 = strtok(NULL, ", ");
        /* printable form of the value; never hand NULL to a %s conversion */
        a1_text = (a1 != NULL) ? a1 : "(null)";

        if(!strncasecmp(dbarg,"host",4))
        {
            data->shared->host = a1;
            if( !pv.quiet_flag ) printf("database:          host = %s\n", a1_text);
        }
        else if(!strncasecmp(dbarg,"port",4))
        {
            data->port = a1;
            if( !pv.quiet_flag ) printf("database:          port = %s\n", a1_text);
        }
        else if(!strncasecmp(dbarg,"user",4))
        {
            data->user = a1;
            if( !pv.quiet_flag ) printf("database:          user = %s\n", a1_text);
        }
        else if(!strncasecmp(dbarg,"password",8))
        {
            if( !pv.quiet_flag ) printf("database: password is set\n");
            data->password = a1;
        }
        else if(!strncasecmp(dbarg,"dbname",6))
        {
            data->shared->dbname = a1;
            if( !pv.quiet_flag ) printf("database: database name = %s\n", a1_text);
        }
        else if(!strncasecmp(dbarg,"sensor_name",11))
        {
            data->sensor_name = a1;
            if( !pv.quiet_flag ) printf("database:   sensor name = %s\n", a1_text);
        }
        else if(!strncasecmp(dbarg,"encoding",8))
        {
            if(a1 == NULL)
            {
                FatalError("database: missing value for the encoding option\n");
            }
            else if(!strncasecmp(a1, "hex", 3))
            {
                data->encoding = ENCODING_HEX;
            }
            else if(!strncasecmp(a1, "base64", 6))
            {
                data->encoding = ENCODING_BASE64;
            }
            else if(!strncasecmp(a1, "ascii", 5))
            {
                data->encoding = ENCODING_ASCII;
            }
            else
            {
                FatalError("database: unknown encoding (%s)", a1);
            }
            if( !pv.quiet_flag ) printf("database: data encoding = %s\n", a1_text);
        }
        else if(!strncasecmp(dbarg,"detail",6))
        {
            if(a1 == NULL)
            {
                FatalError("database: missing value for the detail option\n");
            }
            else if(!strncasecmp(a1, "full", 4))
            {
                data->detail = DETAIL_FULL;
            }
            else if(!strncasecmp(a1, "fast", 4))
            {
                data->detail = DETAIL_FAST;
            }
            else
            {
                FatalError("database: unknown detail level (%s)", a1);
            }
            if( !pv.quiet_flag ) printf("database: detail level  = %s\n", a1_text);
        }

        /* NOTE(review): unlike the first key, only '=' is a delimiter here,
           so a key that is preceded by more than the single separator
           consumed above (e.g. ", " or two spaces) keeps the leading
           characters and matches none of the options. */
        dbarg = strtok(NULL, "=");
    } 

    if(data->shared->dbname == NULL)
    {
        ErrorMessage("database: must enter database name in configuration file\n\n");
        DatabasePrintUsage();
        FatalError("");
    }

    return data;
}

/*
 * Function: FreeQueryNode(SQLQuery *)
 *
 * Purpose: Releases a chain of query nodes: for the given node and every
 *          node reachable through ->next, the query buffer (->val) and the
 *          node itself are freed.
 *
 * Arguments: node => first node of the chain (NULL is accepted)
 *
 * Returns: void function
 *
 */
void FreeQueryNode(SQLQuery * node)
{
    SQLQuery *following;

    while(node != NULL)
    {
        /* remember the successor before the current node is released */
        following = node->next;
        free(node->val);
        free(node);
        node = following;
    }
}

/*
 * Function: NewQueryNode(SQLQuery *, int)
 *
 * Purpose: Allocates a query node with a query buffer of query_size bytes
 *          and, when a parent is given, appends it to the end of the chain
 *          the parent belongs to.
 *
 * Arguments: parent     => any node of an existing chain, or NULL to start
 *                          a new chain
 *            query_size => size of the query buffer in bytes; zero or a
 *                          negative value selects MAX_QUERY_LENGTH
 *
 * Returns: the new node (->val is an empty string, ->next is NULL), or NULL
 *          when memory could not be allocated.  On failure the existing
 *          chain is left untouched.
 *          NOTE(review): the caller visible in this file does not check for
 *          NULL -- confirm how callers should react to allocation failure.
 *
 */
SQLQuery * NewQueryNode(SQLQuery * parent, int query_size)
{
    SQLQuery * rval;

    if(query_size <= 0)
    {
        query_size = MAX_QUERY_LENGTH;
    }

    /* build the node completely before it becomes reachable from the chain */
    rval = (SQLQuery *)malloc(sizeof(SQLQuery));
    if(rval == NULL)
    {
        return NULL;
    }

    rval->val = (char *)malloc((size_t)query_size);
    if(rval->val == NULL)
    {
        free(rval);
        return NULL;
    }
    /* start out as an empty string so the buffer is never read uninitialized */
    rval->val[0] = '\0';
    rval->next = NULL;

    if(parent)
    {
        /* append at the tail of the parent's chain */
        while(parent->next)
        {
            parent = parent->next;
        } 
        parent->next = rval;
    }

    return rval;
}  

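/*
 * Function: Database(Packet *, char * msg, void *arg, Event *event)
 *
 * Purpose: Insert data into the database
 *
 * Arguments: p     => pointer to the current packet data struct (may be NULL)
 *            msg   => pointer to the signature message (NULL is treated as "")
 *            arg   => pointer to the DatabaseData of this plugin instance
 *            event => pointer to the event data (sig_id, sig_rev and
 *                     priority are read from it)
 *
 * Returns: void function
 *
 */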
void Database(Packet *p, char *msg, void *arg, Event *event)
{
    DatabaseData *data = (DatabaseData *)arg;
    SQLQuery * query;
    SQLQuery * root;
    char * tmp, *tmp1, *tmp2, *tmp3;
    char * tmp_not_escaped;
    int i;
    char *select0, *select1, *insert0;
    unsigned int sig_id;
    extern OptTreeNode *otn_tmp;  /* rule node */ 
    ReferenceData *ds_ptr;
    PriorityData *class_ptr;
    int ref_system_id;
    unsigned int ref_id, class_id=0;

    query = NewQueryNode(NULL, 0);
    root = query;

    if(msg == NULL)
    {
        msg = "";
    }

    /*** Build the query for the Event Table ***/
    if(p != NULL)
    {
        tmp = GetTimestamp((time_t *)&p->pkth->ts.tv_sec, data->tz);
    }
    else
    {
        tmp = GetCurrentTimestamp();
    }
#ifdef ENABLE_MSSQL
    if(!strcasecmp(data->shared->dbtype,MSSQL))
    {
        /* SQL Server uses a date format which is slightly
         * different from the ISO-8601 standard generated
         * by GetTimestamp() and GetCurrentTimestamp().  We
         * need to convert from the ISO-8601 format of:
         *   "1998-01-25 23:59:59+14316557"
         * to the SQL Server format of:
         *   "1998-01-25 23:59:59.143"
         */
        if( tmp!=NULL && strlen(tmp)>=22 )
        {
            tmp[19] = '.';
            tmp[23] = '\0';
        }
    }
#endif

       /* Write the signature information 
        *  - Determine the ID # of the signature of this alert 
        */
       select0 = (char *) malloc (MAX_QUERY_LENGTH+1);
       if ( event->sig_rev == 0 ) 
       {
 	  if( event->sig_id == 0) 
          {
             snprintf(select0, MAX_QUERY_LENGTH, 
                      "SELECT sig_id FROM signature WHERE sig_name = '%s' AND"
                      " sig_rev is NULL AND sig_sid is NULL ", snort_escape_string(msg,data));
          }
          else 
          {
             snprintf(select0, MAX_QUERY_LENGTH, 
                      "SELECT sig_id FROM signature WHERE sig_name = '%s' AND"
                      " sig_rev is NULL AND sig_sid = %u ", 
                      snort_escape_string(msg, data), event->sig_id);
          }
       }
       else
       {
 	  if( event->sig_id == 0)
	  {
             snprintf(select0, MAX_QUERY_LENGTH,
                      "SELECT sig_id FROM signature WHERE sig_name = '%s' AND "
                      " sig_rev = %u AND sig_sid is NULL ",
                      snort_escape_string(msg, data), event->sig_rev);
          }
          else
	  {
             snprintf(select0, MAX_QUERY_LENGTH,
                      "SELECT sig_id FROM signature WHERE sig_name = '%s' AND "
                      " sig_rev = %u AND sig_sid = %u ",
                      snort_escape_string(msg, data), event->sig_rev, event->sig_id);
          }
       }

       sig_id = Select(select0, data);

       /* If this signature is detected for the first time
        *  - write the signature
        *  - write the signature's references, classification, priority, id,
        *                          revision number
        * Note: if a signature (identified with a unique text message, revision #) 
        *       initially is logged to the DB without references/classification, 
        *       but later they are added, this information will _not_ be 
        *       stored/updated unless the revision number is changed.
        *       This algorithm is used in order to prevent many DB SELECTs to
        *       verify their presence _every_ time the alert is triggered. 
        */
       if(sig_id == 0)
       {
         /* get classification and priority information  */
         if( otn_tmp )
         {
           class_ptr = (PriorityData *)otn_tmp->ds_list[PLUGIN_PRIORITY_NUMBER];

           if ( class_ptr )
	       {
             /* classification */
             if ( class_ptr->type )
	         {
	            /* Get the ID # of this classification */ 
                select1 = (char *) malloc (MAX_QUERY_LENGTH+1);

                snprintf(select1, MAX_QUERY_LENGTH, 
                         "SELECT sig_class_id FROM sig_class WHERE "
                         " sig_class_name = '%s'", snort_escape_string(class_ptr->type, data));
                class_id = Select(select1, data);

                if ( !class_id )
                {
                   insert0 = (char *) malloc (MAX_QUERY_LENGTH+1);
                   snprintf(insert0, MAX_QUERY_LENGTH,
                            "INSERT INTO sig_class (sig_class_name) VALUES "
                            "('%s')", snort_escape_string(class_ptr->type, data));
                   Insert(insert0, data);
                   free(insert0);
                   class_id = Select(select1, data);
                   if ( !class_id )
                    ErrorMessage("database: unable to write classification\n");
                 }
                free(select1);
	          }
	        }
          }

         insert0 = (char *) malloc (MAX_QUERY_LENGTH+1);
         tmp1 = (char *) malloc (MAX_QUERY_LENGTH+1);
         tmp2 = (char *) malloc (MAX_QUERY_LENGTH+1);
         tmp3 = (char *) malloc (MAX_QUERY_LENGTH/4+1);

         strcpy(tmp1, "sig_name");
         snprintf(tmp2, MAX_QUERY_LENGTH, "'%s'", msg);
         if ( class_id > 0 )
         {
            strcat(tmp1, ",sig_class_id");
            snprintf(tmp3, MAX_QUERY_LENGTH, ",%u", class_id);
            strcat(tmp2, tmp3);
         } 

         if ( event->priority > 0 )
         {
            strcat(tmp1, ",sig_priority");
            snprintf(tmp3, MAX_QUERY_LENGTH, ",%u", event->priority);
            strcat(tmp2, tmp3);
         }

         if ( event->sig_rev > 0 )
         {
            strcat(tmp1, ",sig_rev");
            snprintf(tmp3, MAX_QUERY_LENGTH, ",%u", event->sig_rev);
            strcat(tmp2, tmp3);
         }

         if ( event->sig_id > 0 )
         {
            strcat(tmp1, ",sig_sid");
            snprintf(tmp3, MAX_QUERY_LENGTH, ",%u", event->sig_id);
            strcat(tmp2, tmp3);
         }

         snprintf(insert0, MAX_QUERY_LENGTH,
                  "INSERT INTO signature (%s) VALUES (%s)",
                  tmp1, tmp2);

         Insert(insert0,data);
         free(insert0);
         free(tmp1);
         free(tmp2);
         free(tmp3);

         sig_id = Select(select0,data);
         if(sig_id == 0)
         {
           ErrorMessage("database: Problem inserting a new signature '%s'\n", msg);
         }
         free(select0);
       
         /* add the external rule references  */
         if(otn_tmp)
         {
           ds_ptr = (ReferenceData *)otn_tmp->ds_list[PLUGIN_REFERENCE_NUMBER];
           i = 1;

           while (ds_ptr)
           {
              /* Get the ID # of the reference from the DB */
              select0 = (char *) malloc (MAX_QUERY_LENGTH+1);
              insert0 = (char *) malloc (MAX_QUERY_LENGTH+1);

              /* Note: There is an underlying assumption that the SELECT
               *       will do a case-insensitive comparison.
               */
              snprintf(select0, MAX_QUERY_LENGTH, 
                       "SELECT ref_system_id FROM reference_system WHERE "
                       " ref_system_name = '%s'", snort_escape_string(ds_ptr->system, data));
              snprintf(insert0, MAX_QUERY_LENGTH,
                       "INSERT INTO reference_system (ref_system_name) "
                       "VALUES ('%s')", snort_escape_string(ds_ptr->system, data));
              ref_system_id = Select(select0, data);
              if ( ref_system_id == 0 )
              {
                 Insert(insert0, data);
                 ref_system_id = Select(select0, data);
              }

              free(select0);
              free(insert0);

              if ( ref_system_id > 0 )
              {
                 select0 = (char *) malloc (MAX_QUERY_LENGTH+1);
                 snprintf(select0, MAX_QUERY_LENGTH,
                         "SELECT ref_id FROM reference WHERE "
                         "ref_system_id = %d AND ref_tag = '%s'",
                          ref_system_id, snort_escape_string(ds_ptr->id, data));
                 ref_id = Select(select0, data);

                 /* If this reference is not in the database, write it */
                 if ( ref_id == 0 )
                 {
                   /* truncate the reference tag as necessary */
                   tmp1 = (char *) malloc (101);
                   if ( data->DBschema_version == 103 )
                     snprintf(tmp1, 20, "%s", ds_ptr->id);
                   else if ( data->DBschema_version >= 104 )
                     snprintf(tmp1, 100, "%s", ds_ptr->id);

                    insert0 = (char *) malloc (MAX_QUERY_LENGTH+1);
                    snprintf(insert0, MAX_QUERY_LENGTH,
                             "INSERT INTO reference (ref_system_id, ref_tag) VALUES "
                             "(%d, '%s')", ref_system_id, snort_escape_string(tmp1, data));
                    Insert(insert0, data);
                    ref_id = Select(select0, data);
                    free(insert0); 
                    free(tmp1);

                    if ( ref_id == 0 )
                    {
                       ErrorMessage("database: Unable to insert the alert reference into the DB\n");
                    }
                 }
                 free(select0);

                 insert0 = (char *) malloc (MAX_QUERY_LENGTH+1);
                 snprintf(insert0, MAX_QUERY_LENGTH,
                          "INSERT INTO sig_reference (sig_id, ref_seq, ref_id) "
                          "VALUES (%u, %d, %u)", sig_id, i, ref_id);
                 Insert(insert0, data);
                 free(insert0);
                 ++i;
              }
              else
              {
                 ErrorMessage("database: Unable to insert unknown reference tag ('%s') used in rule.\n", ds_ptr->system);
              }
 
              ds_ptr = ds_ptr->next;
            }
          }
     }
     else
        free(select0);

     snprintf(query->val, MAX_QUERY_LENGTH,
              "INSERT INTO event (sid,cid,signature,timestamp) VALUES "
              "('%u', '%u', '%u', '%s')",
              data->shared->sid, data->shared->cid, sig_id, tmp);     
    free(tmp); 

/* We do not log fragments! They are assumed to be handled 
    by the fragment reassembly pre-processor */

    if(p != NULL)
    {
        if((!p->frag_flag) && (p->iph)) 
        {
	  /* query = NewQueryNode(query, 0); */
            if(p->iph->ip_proto == IPPROTO_ICMP && p->icmph)
            {
                query = NewQueryNode(query, 0);
                /*** Build a query for the ICMP Header ***/
                if(data->detail)
                {
                    if(p->ext)
                    {
                        snprintf(query->val, MAX_QUERY_LENGTH, 
                                 "INSERT INTO icmphdr (sid, cid, icmp_type, icmp_code, "
                                 "icmp_csum, icmp_id, icmp_seq) "
                                 "VALUES ('%u','%u','%u','%u','%u','%u','%u')",
                                 data->shared->sid, data->shared->cid, p->icmph->type, p->icmph->code,
                                 ntohs(p->icmph->csum), ntohs(p->ext->id), ntohs(p->ext->seqno));
                    }
                    else
                    {
                        snprintf(query->val, MAX_QUERY_LENGTH, 
                                 "INSERT INTO icmphdr (sid, cid, icmp_type, icmp_code, "
                                 "icmp_csum) "
                                 "VALUES ('%u','%u','%u','%u','%u')",
                                 data->shared->sid, data->shared->cid, p->icmph->type, p->icmph->code,
                                 ntohs(p->icmph->csum));
                    }
                }
                else
                {
                    snprintf(query->val, MAX_QUERY_LENGTH, 
                             "INSERT INTO icmphdr (sid, cid, icmp_type, icmp_code) "
                             "VALUES ('%u','%u','%u','%u')",
                             data->shared->sid, data->shared->cid, p->icmph->type, p->icmph->code);
                }
            }
            else if(p->iph->ip_proto == IPPROTO_TCP && p->tcph)
            {
                query = NewQueryNode(query, 0);
                /*** Build a query for the TCP Header ***/
                if(data->detail)
                {
                    snprintf(query->val, MAX_QUERY_LENGTH, 
                             "INSERT INTO tcphdr "

                             "(sid, cid, tcp_sport, tcp_dport, tcp_seq,"
                             " tcp_ack, tcp_off, tcp_res, tcp_flags, tcp_win,"
                             " tcp_csum, tcp_urp) "

                             "VALUES ('%u','%u','%u','%u','%lu','%lu','%u',"
                             "'%u','%u','%u','%u','%u')",

                             data->shared->sid, data->shared->cid, ntohs(p->tcph->th_sport), 
                             ntohs(p->tcph->th_dport), (u_long)ntohl(p->tcph->th_seq),
                             (u_long)ntohl(p->tcph->th_ack), p->tcph->th_off, 
                             p->tcph->th_x2, p->tcph->th_flags, 
                             ntohs(p->tcph->th_win), ntohs(p->tcph->th_sum),
                             ntohs(p->tcph->th_urp));
                }
                else
                {
                    snprintf(query->val, MAX_QUERY_LENGTH, 
                             "INSERT INTO tcphdr "
                             "(sid,cid,tcp_sport,tcp_dport,tcp_flags) "
                             "VALUES ('%u','%u','%u','%u','%u')",
                             data->shared->sid, data->shared->cid, ntohs(p->tcph->th_sport), 
                             ntohs(p->tcph->th_dport), p->tcph->th_flags);
                }


                if(data->detail)
                {
                    /*** Build the query for TCP Options ***/
                    for(i=0; i < (int)(p->tcp_option_count); i++)
                    {
                        query = NewQueryNode(query, 0);
                        if((data->encoding == ENCODING_HEX) || (data->encoding == ENCODING_ASCII))
                        {
                            tmp = fasthex(p->tcp_options[i].data, (p->tcp_options[i].len==0)?0:(p->tcp_options[i].len-2)); 
                        }
                        else
                        {
                            tmp = base64(p->tcp_options[i].data, (p->tcp_options[i].len==0)?0:(p->tcp_options[i].len-2)); 
                        }
                        snprintf(query->val, MAX_QUERY_LENGTH, 
                                 "INSERT INTO opt "
                                 "(sid,cid,optid,opt_proto,opt_code,opt_len,opt_data) "
                                 "VALUES ('%u','%u','%u','%u','%u','%u','%s')",
                                 data->shared->sid, data->shared->cid, i, 6, p->tcp_options[i].code,
                                 p->tcp_options[i].len, tmp); 
                        free(tmp);
                    }
                }
            }
            else if(p->iph->ip_proto == IPPROTO_UDP && p->udph)
            {
                query = NewQueryNode(query, 0);
                /*** Build the query for the UDP Header ***/
                if(data->detail)
                {
                    snprintf(query->val, MAX_QUERY_LENGTH,
                             "INSERT INTO udphdr "
                             "(sid, cid, udp_sport, udp_dport, udp_len, udp_csum) "
                             "VALUES ('%u', '%u', '%u', '%u', '%u', '%u')",
                             data->shared->sid, data->shared->cid, ntohs(p->udph->uh_sport), 
                             ntohs(p->udph->uh_dport), ntohs(p->udph->uh_len),
                             ntohs(p->udph->uh_chk));
                }
                else
                {
                    snprintf(query->val, MAX_QUERY_LENGTH,
                             "INSERT INTO udphdr "
                             "(sid, cid, udp_sport, udp_dport) "
                             "VALUES ('%u', '%u', '%u', '%u')",
                             data->shared->sid, data->shared->cid, ntohs(p->udph->uh_sport), 
                             ntohs(p->udph->uh_dport));
                }
            }
        }   

        /*** Build the query for the IP Header ***/
        if ( p->iph )
        {
          query = NewQueryNode(query, 0);

          if(data->detail)
          {
            snprintf(query->val, MAX_QUERY_LENGTH, 

                     "INSERT INTO iphdr "
                     "(sid, cid, ip_src, ip_dst, ip_ver,"
                     "ip_hlen, ip_tos, ip_len, ip_id, ip_flags, ip_off,"
                     "ip_ttl, ip_proto, ip_csum) "

                     "VALUES ('%u','%u','%lu',"
                     "'%lu','%u',"
                     "'%u','%u','%u','%u','%u','%u',"
                     "'%u','%u','%u')",

                     data->shared->sid, data->shared->cid, (u_long)ntohl(p->iph->ip_src.s_addr), 
                     (u_long)ntohl(p->iph->ip_dst.s_addr), 
                     p->iph->ip_ver, p->iph->ip_hlen, 
                     p->iph->ip_tos, ntohs(p->iph->ip_len), ntohs(p->iph->ip_id), 
                     p->frag_flag, ntohs(p->frag_offset), p->iph->ip_ttl, 
                     p->iph->ip_proto, ntohs(p->iph->ip_csum));
          }
          else
          {
            snprintf(query->val, MAX_QUERY_LENGTH, 

                     "INSERT INTO iphdr "
                     "(sid, cid, ip_src, ip_dst, ip_proto) "

                     "VALUES ('%u','%u','%lu','%lu','%u')",

                     data->shared->sid, data->shared->cid, (u_long)ntohl(p->iph->ip_src.s_addr),
                     (u_long)ntohl(p->iph->ip_dst.s_addr), p->iph->ip_proto);
          }

          /*** Build querys for the IP Options ***/
          if(data->detail)
          {
            for(i=0 ; i < (int)(p->ip_option_count); i++)
            {
                if(&p->ip_options[i])
                {
                    query = NewQueryNode(query, 0);
                    if((data->encoding == ENCODING_HEX) || (data->encoding == ENCODING_ASCII))
                    {
                        tmp = fasthex(p->ip_options[i].data, p->ip_options[i].len); 
                    }
                    else
                    {
                        tmp = base64(p->ip_options[i].data, p->ip_options[i].len); 
                    }
                    snprintf(query->val, MAX_QUERY_LENGTH, 
                             "INSERT INTO opt "
                             "(sid,cid,optid,opt_proto,opt_code,opt_len,opt_data) "
                             "VALUES ('%u','%u','%u','%u','%u','%u','%s')",
                             data->shared->sid, data->shared->cid, i, 0, p->ip_options[i].code,
                             p->ip_options[i].len, tmp); 
                    free(tmp);
                }
            }
          }
        }

        /*** Build query for the payload ***/
        if ( p->data )
        {
          if(data->detail)
          {
            if(p->dsize)
            {
                query = NewQueryNode(query, p->dsize * 2 + MAX_QUERY_LENGTH);
                if(data->encoding == ENCODING_BASE64)
                {
                    tmp_not_escaped = base64(p->data, p->dsize);
                }
                else
                {
                    if(data->encoding == ENCODING_ASCII)
                    {
                        tmp_not_escaped = ascii(p->data, p->dsize);
                    }
                    else
                    {
                        tmp_not_escaped = fasthex(p->data, p->dsize);
                    }
                }

                tmp = snort_escape_string(tmp_not_escaped, data);

                snprintf(query->val, MAX_QUERY_LENGTH - 3, 
                         "INSERT INTO data "
                         "(sid,cid,data_payload) "
                         "VALUES ('%u','%u','%s",
                         data->shared->sid, data->shared->cid, tmp);
                strcat(query->val, "')");
                free (tmp);
                free (tmp_not_escaped);
            }
          }
        }
    }

    /* Execute the qureies */
    query = root;
    while(query)
    {
        Insert(query->val,data); 
        query = query->next;
    }

    FreeQueryNode(root); 

    data->shared->cid++;

    /* A Unixodbc bugfix */
#ifdef ENABLE_ODBC
    if(data->shared->cid == 600)
    {
        data->shared->cid = 601;
    }
#endif
}

/* Some of the code in this function is from the 
   mysql_real_escape_string() function distributed with mysql.

   Those portions of this function remain
   Copyright (C) 2000 MySQL AB & MySQL Finland AB & TCX DataKonsult AB

   We needed a more general case that was not MySQL specific so there
   were small modifications made to the mysql_real_escape_string() 
   function. */

/* Function: snort_escape_string(char * from, DatabaseData * data)
 *
 * Purpose: Build an escaped copy of a NUL-terminated string so that it can
 *          be placed inside a quoted SQL value.
 *
 *          When built with ENABLE_ORACLE and the configured database type
 *          is Oracle: newline, carriage return and Ctrl-Z ('\032') are
 *          written as the two character sequences \n, \r and \Z, and a
 *          single quote is doubled.
 *
 *          For every other database type: newline, carriage return and
 *          Ctrl-Z are written as \n, \r and \Z, and backslash, single
 *          quote and double quote are each prefixed with a backslash.
 *
 * Arguments: from (NUL-terminated string to escape; it is only read)
 *            data (database information; only consulted to detect Oracle)
 *
 * Returns: newly allocated escaped string which the caller must free().
 *          Running out of memory is reported through FatalError().
 */
char * snort_escape_string(char * from, DatabaseData * data)
{
    char * to;
    char * to_start;
    char * end;
    size_t from_length;

    from_length = strlen(from);

    /* Worst case every input byte becomes two output bytes, plus the NUL */
    to_start = (char *)malloc(from_length * 2 + 1);
    if(to_start == NULL)
    {
        FatalError("database: out of memory while escaping a string\n");
        /* FatalError is expected not to return; this return only keeps the
           NULL pointer from being written through if it ever does. */
        return NULL;
    }

    to = to_start;
    end = from + from_length;

#ifdef ENABLE_ORACLE
    if (!strcasecmp(data->shared->dbtype,ORACLE))
    {
      for (; from != end; from++)
      {
        switch(*from)
        {
          case '\n':                               /* Must be escaped for logs */
            *to++= '\\';
            *to++= 'n';
            break;
          case '\r':
            *to++= '\\';
            *to++= 'r';
            break;
          case '\'':                               /* Quote is doubled, not backslashed */
            *to++= '\'';
            *to++= '\'';
            break;
          case '\032':                     /* This gives problems on Win32 */
            *to++= '\\';
            *to++= 'Z';
            break;
          default:
            *to++= *from;
        }
      }
    }
    else
#endif
    {
      /* No case for a NUL byte is needed: the length comes from strlen(),
         so a NUL can never appear before `end`. */
      for(; from != end; from++)
      {
        switch(*from)
        {
          case '\n':              /* Must be escaped for logs */
            *to++= '\\';
            *to++= 'n';
            break;
          case '\r':
            *to++= '\\';
            *to++= 'r';
            break;
          case '\\':
            *to++= '\\';
            *to++= '\\';
            break;
          case '\'':
            *to++= '\\';
            *to++= '\'';
            break;
          case '"':               /* Better safe than sorry */
            *to++= '\\';
            *to++= '"';
            break;
          case '\032':            /* This gives problems on Win32 */
            *to++= '\\';
            *to++= 'Z';
            break;
          default:
            *to++= *from;
        }
      }
    }
    *to=0;
    return to_start;
}

/* Function: CheckDBVersion(DatabaseData * data)
 *
 * Purpose: To determine the version number of the underlying DB schema
 *
 * Arguments: database information
 *
 * Returns: version number of the schema, i.e. whatever Select() returns
 *          for the vseq query (Select() returns 0 when the query fails)
 */
int CheckDBVersion(DatabaseData * data)
{
    /* The query is a constant, so it lives in a writable local array
       (Select() takes a non-const char *) instead of a heap buffer. */
#ifndef ENABLE_MSSQL
    char select0[] = "SELECT vseq FROM schema";
#else
    /* "schema" is a keyword in SQL Server, so quote it with square brackets */
    char select0[] = "SELECT vseq FROM [schema]";
#endif

    return Select(select0, data);
}

/* Function: Insert(char * query, DatabaseData * data)
 *
 * Purpose: Database independent function for SQL inserts
 *
 * Arguments: query (An SQL insert)
 *            data  (database information; data->shared->dbtype selects
 *                   which of the compiled-in back ends runs the statement)
 *
 * Returns: 1 if successful, 0 if fail
 *
 * Notes: For MySQL a statement that starts with STANDARD_INSERT is sent
 *        with that prefix replaced by MYSQL_INSERT.  The replacement is
 *        built in a private buffer; the caller's query is left untouched.
 */
int Insert(char * query, DatabaseData * data)
{
    int result = 0;
#ifdef ENABLE_MYSQL
    char * modified_query;
    char * mysql_stmt;
    size_t standard_len;
    size_t modified_size;
#endif

#ifdef ENABLE_POSTGRESQL
    if(!strcasecmp(data->shared->dbtype,POSTGRESQL))
    {
        data->p_result = PQexec(data->p_connection,query);
        if(PQresultStatus(data->p_result) == PGRES_COMMAND_OK)
        {
            result = 1;
        }
        else
        {
            if(PQerrorMessage(data->p_connection)[0] != '\0')
            {
                ErrorMessage("database: postgresql_error: %s\n", PQerrorMessage(data->p_connection));
            }
        }
        /* Every PGresult has to be released, otherwise one leaks per query */
        PQclear(data->p_result);
        data->p_result = NULL;
    }
#endif

#ifdef ENABLE_MYSQL
    if(!strcasecmp(data->shared->dbtype,MYSQL))
    {
        mysql_stmt = query;
        modified_query = NULL;
        standard_len = strlen(STANDARD_INSERT);

        if(!strncasecmp(query, STANDARD_INSERT, standard_len))
        {
            /* Size the buffer from the real prefix lengths and never copy
               the (possibly longer) statement back over the caller's
               buffer, whose capacity is unknown here. */
            modified_size = strlen(MYSQL_INSERT) + strlen(query + standard_len) + 1;
            modified_query = (char *)malloc(modified_size);
            if(modified_query != NULL)
            {
                snprintf(modified_query, modified_size, "%s%s",
                         MYSQL_INSERT, query + standard_len);
                mysql_stmt = modified_query;
            }
            /* If the allocation failed the unmodified statement is sent */
        }

        if(!(mysql_query(data->m_sock,mysql_stmt)))
        {
            result = 1;
        }
        else
        {
            if(mysql_errno(data->m_sock))
            {
              ErrorMessage("database: mysql_error: %s\nSQL=%s\n",
                           mysql_error(data->m_sock), mysql_stmt);
            }
        }
        free(modified_query);
    }
#endif

#ifdef ENABLE_ODBC
    if(!strcasecmp(data->shared->dbtype,ODBC))
    {
        if(SQLAllocStmt(data->u_connection, &data->u_statement) == SQL_SUCCESS)
        {
            if(SQLPrepare(data->u_statement, query, SQL_NTS) == SQL_SUCCESS)
            {
                if(SQLExecute(data->u_statement) == SQL_SUCCESS)
                {
                    result = 1;
                }
            }
            /* A statement handle is allocated per call, so drop it again */
            SQLFreeStmt(data->u_statement, SQL_DROP);
        }
    }
#endif

#ifdef ENABLE_ORACLE
    if(!strcasecmp(data->shared->dbtype,ORACLE))
    {
        if (OCIStmtPrepare(data->o_statement, data->o_error, query, strlen(query), OCI_NTV_SYNTAX, OCI_DEFAULT) ||
            OCIStmtExecute(data->o_servicecontext, data->o_statement, data->o_error, 1,  0, NULL, NULL, OCI_COMMIT_ON_SUCCESS))
        {
            OCIErrorGet(data->o_error, 1, NULL, &data->o_errorcode, data->o_errormsg, sizeof(data->o_errormsg), OCI_HTYPE_ERROR);
            ErrorMessage("database: oracle_error: %s\n", data->o_errormsg);
        }
        else
        {
            result = 1;
        }
    }
#endif

#ifdef ENABLE_MSSQL
    if(!strcasecmp(data->shared->dbtype,MSSQL))
    {
        SAVESTATEMENT(query);
        dbfreebuf(data->ms_dbproc);
        if( dbcmd(data->ms_dbproc, query) == SUCCEED )
        {
            if( dbsqlexec(data->ms_dbproc) == SUCCEED )
            {
                if( dbresults(data->ms_dbproc) == SUCCEED )
                {
                    /* Discard any rows; no column is bound for an insert */
                    while (dbnextrow(data->ms_dbproc) != NO_MORE_ROWS)
                    {
                    }
                    /* Success is the statement having executed, not a row
                       having come back (an INSERT returns none). */
                    result = 1;
                }
            }
        }
        CLEARSTATEMENT();
    }
#endif

#ifdef DEBUG
    if(result)
    {
        printf("database(debug): (%s) executed\n", query);
    }
    else
    {
        printf("database(debug): (%s) failed\n", query);
    }
#endif

    return result;
}

/* Function: Select(char * query, DatabaseData * data)
 *
 * Purpose: Database independent function for SQL selects that
 *          return a non zero int
 *
 * Arguments: query (An SQL select whose first column is an integer)
 *            data  (database information; data->shared->dbtype selects
 *                   which of the compiled-in back ends runs the statement)
 *
 * Returns: result of query if successful, 0 if fail
 */
int Select(char * query, DatabaseData * data)
{
    int result = 0;

#ifdef ENABLE_POSTGRESQL
    if(!strcasecmp(data->shared->dbtype,POSTGRESQL))
    {
        data->p_result = PQexec(data->p_connection,query);
        if(PQresultStatus(data->p_result) == PGRES_TUPLES_OK)
        {
            if(PQntuples(data->p_result))
            {
                if((PQntuples(data->p_result)) > 1)
                {
                    ErrorMessage("database: warning (%s) returned more than one result\n", query);
                    result = 0;
                }
                else
                {
                    result = atoi(PQgetvalue(data->p_result,0,0));
                }
            }
        }
        if(!result)
        {
            if(PQerrorMessage(data->p_connection)[0] != '\0')
            {
                ErrorMessage("database: postgresql_error: %s\n",PQerrorMessage(data->p_connection));
            }
        }
        /* The value has been copied out, so the PGresult can be released;
           without this one result leaks per query. */
        PQclear(data->p_result);
        data->p_result = NULL;
    }
#endif

#ifdef ENABLE_MYSQL
    if(!strcasecmp(data->shared->dbtype,MYSQL))
    {
        if(mysql_query(data->m_sock,query))
        {
            result = 0;
        }
        else
        {
            data->m_result = mysql_use_result(data->m_sock);
            if(data->m_result == NULL)
            {
                result = 0;
            }
            else
            {
                if((data->m_row = mysql_fetch_row(data->m_result)))
                {
                    if(data->m_row[0] != NULL)
                    {
                        result = atoi(data->m_row[0]);
                    }
                }
                /* Only free a result set that was actually obtained */
                mysql_free_result(data->m_result);
                data->m_result = NULL;
            }
        }
        if(!result)
        {
            if(mysql_errno(data->m_sock))
            {
                ErrorMessage("database: mysql_error: %s\n", mysql_error(data->m_sock));
            }
        }
    }
#endif

#ifdef ENABLE_ODBC
    if(!strcasecmp(data->shared->dbtype,ODBC))
    {
        if(SQLAllocStmt(data->u_connection, &data->u_statement) == SQL_SUCCESS)
        {
            if(SQLPrepare(data->u_statement, query, SQL_NTS) == SQL_SUCCESS
               && SQLExecute(data->u_statement) == SQL_SUCCESS
               && SQLRowCount(data->u_statement, &data->u_rows) == SQL_SUCCESS
               && data->u_rows)
            {
                if(data->u_rows > 1)
                {
                    ErrorMessage("database: warning (%s) returned more than one result\n", query);
                    result = 0;
                }
                else
                {
                    if(SQLFetch(data->u_statement) == SQL_SUCCESS)
                    {
                        if(SQLGetData(data->u_statement,1,SQL_INTEGER,&data->u_col,
                                      sizeof(data->u_col), NULL) == SQL_SUCCESS)
                        {
                            result = (int)data->u_col;
                        }
                    }
                }
            }
            /* A statement handle is allocated per call, so drop it again */
            SQLFreeStmt(data->u_statement, SQL_DROP);
        }
    }
#endif

#ifdef ENABLE_ORACLE
    if(!strcasecmp(data->shared->dbtype,ORACLE))
    {
        /* The fetched integer is written straight into `result` */
        if (OCIStmtPrepare(data->o_statement, data->o_error, query, strlen(query), OCI_NTV_SYNTAX, OCI_DEFAULT) ||
            OCIStmtExecute(data->o_servicecontext, data->o_statement, data->o_error, 0, 0, NULL, NULL, OCI_DEFAULT) ||
            OCIDefineByPos (data->o_statement, &data->o_define, data->o_error, 1, &result, sizeof(result), SQLT_INT, 0, 0, 0, OCI_DEFAULT) ||
            OCIStmtFetch (data->o_statement, data->o_error, 1, OCI_FETCH_NEXT, OCI_DEFAULT))
        {
            OCIErrorGet(data->o_error, 1, NULL, &data->o_errorcode, data->o_errormsg, sizeof(data->o_errormsg), OCI_HTYPE_ERROR);
            ErrorMessage("database: oracle_error: %s\n", data->o_errormsg);
        }
    }
#endif

#ifdef ENABLE_MSSQL
    if(!strcasecmp(data->shared->dbtype,MSSQL))
    {
        SAVESTATEMENT(query);
        dbfreebuf(data->ms_dbproc);
        if( dbcmd(data->ms_dbproc, query) == SUCCEED )
            if( dbsqlexec(data->ms_dbproc) == SUCCEED )
                if( dbresults(data->ms_dbproc) == SUCCEED )
                    if( dbbind(data->ms_dbproc, 1, INTBIND, (DBINT) 0, (BYTE *) &data->ms_col) == SUCCEED )
                        while (dbnextrow(data->ms_dbproc) != NO_MORE_ROWS)
                        {
                            /* The value from the last row read is kept */
                            result = (int)data->ms_col;
                        }
        CLEARSTATEMENT();
    }
#endif

#ifdef DEBUG
    if(result)
    {
        printf("database(debug): (%s) returned %u\n", query, result);
    }
    else
    {
        printf("database(debug): (%s) failed\n", query);
    }
#endif

    return result;
}


/* Function: Connect(DatabaseData * data)
 *
 * Purpose: Database independent function to initiate a database
 *          connection
 *
 * Arguments: data (database information; data->shared->dbtype selects
 *                  the back end, and the resulting connection handles are
 *                  stored back into this structure)
 *
 * Returns: nothing.  Every failure to connect is reported through
 *          FatalError().
 */

void Connect(DatabaseData * data)
{
#ifdef ENABLE_MYSQL
    int mysql_port;
#endif
#ifdef ENABLE_ODBC
    SQLRETURN odbc_rc;
#endif

#ifdef ENABLE_POSTGRESQL
    if(!strcasecmp(data->shared->dbtype,POSTGRESQL))
    {
        data->p_connection = PQsetdbLogin(data->shared->host,data->port,NULL,NULL,data->shared->dbname,data->user,data->password);
        if(PQstatus(data->p_connection) == CONNECTION_BAD)
        {
            /* Report why before the connection object (and its message)
               is destroyed. */
            if(PQerrorMessage(data->p_connection)[0] != '\0')
            {
                ErrorMessage("database: postgresql_error: %s\n", PQerrorMessage(data->p_connection));
            }
            PQfinish(data->p_connection);
            FatalError("database: Connection to database '%s' failed\n", data->shared->dbname);
        }
    }
#endif

#ifdef ENABLE_MYSQL
    if(!strcasecmp(data->shared->dbtype,MYSQL))
    {
        data->m_sock = mysql_init(NULL);
        if(data->m_sock == NULL)
        {
            FatalError("database: Connection to database '%s' failed\n", data->shared->dbname);
        }

        /* Port 0 is passed when no port was configured */
        mysql_port = 0;
        if(data->port != NULL)
        {
            mysql_port = atoi(data->port);
        }

        if(mysql_real_connect(data->m_sock, data->shared->host, data->user, data->password, data->shared->dbname, mysql_port, NULL, 0) == 0)
        {
            if(mysql_errno(data->m_sock))
            {
                FatalError("database: mysql_error: %s\n", mysql_error(data->m_sock));
            }
            FatalError("database: Failed to logon to database '%s'\n", data->shared->dbname);
        }
    }
#endif

#ifdef ENABLE_ODBC
    if(!strcasecmp(data->shared->dbtype,ODBC))
    {
        if(!(SQLAllocEnv(&data->u_handle) == SQL_SUCCESS))
        {
            FatalError("database: unable to allocate ODBC environment\n");
        }
        if(!(SQLAllocConnect(data->u_handle, &data->u_connection) ==
             SQL_SUCCESS))
        {
            FatalError("database: unable to allocate ODBC connection handle\n");
        }

        odbc_rc = SQLConnect(data->u_connection, data->shared->dbname, SQL_NTS, data->user, SQL_NTS, data->password, SQL_NTS);
        /* SQL_SUCCESS_WITH_INFO is also a successful connect; the driver
           merely has a diagnostic message to offer. */
        if(odbc_rc != SQL_SUCCESS && odbc_rc != SQL_SUCCESS_WITH_INFO)
        {
            FatalError("database: ODBC unable to connect\n");
        }
    }
#endif

#ifdef ENABLE_ORACLE
    if(!strcasecmp(data->shared->dbtype,ORACLE))
    {
      /* Each OCI call returns non-zero on failure, so the chain stops at
         the first step that fails.  The environment is initialised once. */
      if (OCIInitialize(OCI_DEFAULT, NULL, NULL, NULL, NULL) ||
         OCIEnvInit(&data->o_environment, OCI_DEFAULT, 0, NULL) ||
         OCIHandleAlloc(data->o_environment, (dvoid **)&data->o_error, OCI_HTYPE_ERROR, (size_t) 0, NULL) ||
         OCILogon(data->o_environment, data->o_error, &data->o_servicecontext,
                  data->user, strlen(data->user), data->password, strlen(data->password),
                  data->shared->dbname, strlen(data->shared->dbname)) ||
         OCIHandleAlloc(data->o_environment, (dvoid **)&data->o_statement, OCI_HTYPE_STMT, 0, NULL))
      {
         OCIErrorGet(data->o_error, 1, NULL, &data->o_errorcode, data->o_errormsg, sizeof(data->o_errormsg), OCI_HTYPE_ERROR);
         ErrorMessage("database: oracle_error: %s\n", data->o_errormsg);
         FatalError("database: Connection to database '%s' failed\n", data->shared->dbname);
      }
    }
#endif

#ifdef ENABLE_MSSQL
    if(!strcasecmp(data->shared->dbtype,MSSQL))
    {
        CLEARSTATEMENT();
        dberrhandle(mssql_err_handler);
        dbmsghandle(mssql_msg_handler);

        if( dbinit() != NULL )
        {
            data->ms_login = dblogin();
            if( data->ms_login == NULL )
            {
                FatalError("database: Failed to allocate login structure\n");
            }
            /* Set up some informational values which are stored with the connection */
            DBSETLUSER (data->ms_login, data->user);
            DBSETLPWD  (data->ms_login, data->password);
            DBSETLAPP  (data->ms_login, "snort");

            data->ms_dbproc = dbopen(data->ms_login, data->shared->host);
            if( data->ms_dbproc == NULL )
            {
                FatalError("database: Failed to logon to host '%s'\n", data->shared->host);
            }
            else
            {
                if( dbuse( data->ms_dbproc, data->shared->dbname ) != SUCCEED )
                {
                    FatalError("database: Unable to change context to database '%s'\n", data->shared->dbname);
                }
            }
        }
        else
        {
            FatalError("database: Connection to database '%s' failed\n", data->shared->dbname);
        }
        CLEARSTATEMENT();
    }
#endif
}

/* Function: Disconnect(DatabaseData * data)
 *
 * Purpose: Database independent function to close a connection
 *
 * Arguments: data (database information; may be NULL, in which case
 *                  nothing is done)
 *
 * Returns: nothing
 */

void Disconnect(DatabaseData * data)
{
    /* Check before the first dereference: the closing message below
       reads data->shared. */
    if(data == NULL)
    {
        return;
    }

    if( !pv.quiet_flag ) printf("database: Closing %s connection to database \"%s\"\n", data->shared->dbtype, data->shared->dbname);

#ifdef ENABLE_POSTGRESQL
    if(!strcasecmp(data->shared->dbtype,POSTGRESQL))
    {
        if(data->p_connection) PQfinish(data->p_connection);
    }
#endif

#ifdef ENABLE_MYSQL
    if(!strcasecmp(data->shared->dbtype,MYSQL))
    {
        if(data->m_sock) mysql_close(data->m_sock);
    }
#endif

#ifdef ENABLE_ODBC
    if(!strcasecmp(data->shared->dbtype,ODBC))
    {
        if(data->u_handle)
        {
            SQLDisconnect(data->u_connection);
            /* The connection handle has to be released before the
               environment that owns it. */
            SQLFreeHandle(SQL_HANDLE_DBC, data->u_connection);
            SQLFreeHandle(SQL_HANDLE_ENV, data->u_handle);
        }
    }
#endif

#ifdef ENABLE_MSSQL
    if(!strcasecmp(data->shared->dbtype,MSSQL))
    {
        CLEARSTATEMENT();
        if( data->ms_dbproc != NULL )
        {
            dbfreelogin(data->ms_login);
            data->ms_login = NULL;
            dbclose(data->ms_dbproc);
            data->ms_dbproc = NULL;
        }
    }
#endif
}

/* Function: DatabasePrintUsage()
 *
 * Purpose: Write the usage text for the database plugin to stdout.
 *          Every table entry is printed with puts(), so each one is
 *          followed by a newline; entries that themselves end in "\n"
 *          therefore produce a blank line after them.
 */
void DatabasePrintUsage()
{
    static const char * const usage_text[] =
    {
        "\nUSAGE: database plugin\n",

        " output database: [log | alert], [type of database], [parameter list]\n",
        " [log | alert] selects whether the plugin will use the alert or",
        " log facility.\n",

        " For the first argument, you must supply the type of database.",
        " The possible values are mysql, postgresql, unixodbc, oracle and",
        " mssql (oracle support is beta in snort release 1.7, and mssql",
        " support is beta in snort release 1.8).\n",

        " The parameter list consists of key value pairs. The proper",
        " format is a list of key=value pairs each separated a space.\n",

        " The only parameter that is absolutely necessary is \"dbname\".",
        " All other parameters are optional but may be necessary",
        " depending on how you have configured your RDBMS.\n",

        " dbname - the name of the database you are connecting to\n",

        " host - the host the RDBMS is on\n",

        " port - the port number the RDBMS is listening on\n",

        " user - connect to the database as this user\n",

        " password - the password for given user\n",

        " sensor_name - specify your own name for this snort sensor. If you",
        "        do not specify a name one will be generated automatically\n",

        " encoding - specify a data encoding type (hex, base64, or ascii)\n",

        " detail - specify a detail level (full or fast)\n",

        " FOR EXAMPLE:",
        " The configuration I am currently using is MySQL with the database",
        " name of \"snort\". The user \"jed@localhost\" has INSERT and SELECT",
        " privileges on the \"snort\" database and does not require a password.",
        " The following line enables snort to log to this database.\n",

        " output database: log, mysql, dbname=snort user=jed host=localhost\n",

        NULL    /* end of table */
    };
    const char * const *entry;

    for(entry = usage_text; *entry != NULL; entry++)
    {
        puts(*entry);
    }
}

void SpoDatabaseCleanExitFunction(int signal, void *arg)
{
    DatabaseData *data = (DatabaseData *)arg;
#ifdef DEBUG
    printf("database(debug): entered SpoDatabaseCleanExitFunction\n");
#endif
    Disconnect(data); 
    if(data) 
		free(data);
	if(--instances == 0)
		FreeSharedDataList();
}

void SpoDatabaseRestartFunction(int signal, void *arg)
{
    DatabaseData *data = (DatabaseData *)arg;
#ifdef DEBUG
    printf("database(debug): entered SpoDatabaseRestartFunction\n");
#endif
    Disconnect(data);
    if(data) 
		free(data);
	if(--instances == 0)
		FreeSharedDataList();
}

/* Function: FreeSharedDataList()
 *
 * Purpose: Release every node of the shared data list together with the
 *          data each node points to, leaving the list head NULL.
 */
void FreeSharedDataList()
{
    SharedDatabaseDataNode *node = sharedDataList;
    SharedDatabaseDataNode *following;

    while(node != NULL)
    {
        /* Remember the successor before the node itself is released */
        following = node->next;
        free(node->data);
        free(node);
        node = following;
    }

    sharedDataList = NULL;
}



#ifdef ENABLE_MSSQL
/*
 * The functions mssql_err_handler() and mssql_msg_handler() are callbacks that are registered
 * when we connect to SQL Server.  They get called whenever SQL Server issues errors or messages.
 * This should only occur whenever an error has occurred, or when the connection switches to
 * a different database within the server.
 */
/* Error callback registered with dberrhandle() in Connect().
 *
 * Logs the DB-Library error text, then the Net-Lib and operating-system
 * details when they are present, and (with ENABLE_MSSQL_DEBUG) the
 * statement recorded in g_CurrentStatement.
 *
 * Returns INT_EXIT when dbproc is NULL or DBDEAD() reports it dead,
 * otherwise INT_CANCEL.
 */
static int mssql_err_handler(PDBPROCESS dbproc, int severity, int dberr, int oserr, LPCSTR dberrstr, LPCSTR oserrstr)
{
    int have_os_error = (oserr != DBNOERR);

    ErrorMessage("database: DB-Library error:\n\t%s\n", dberrstr);

    if(severity == EXCOMM)
    {
        if(have_os_error || oserrstr != NULL)
        {
            ErrorMessage("database: Net-Lib error %d:  %s\n", oserr, oserrstr);
        }
    }

    if(have_os_error)
    {
        ErrorMessage("database: Operating-system error:\n\t%s\n", oserrstr);
    }

#ifdef ENABLE_MSSQL_DEBUG
    if(strlen(g_CurrentStatement) != 0)
    {
        ErrorMessage("database:  The above error was caused by the following statement:\n%s\n", g_CurrentStatement);
    }
#endif

    /* DBDEAD() is only evaluated for a non-NULL dbproc */
    if(dbproc != NULL && !DBDEAD(dbproc))
    {
        return INT_CANCEL;
    }
    return INT_EXIT;
}


/* Message callback registered with dbmsghandle() in Connect().
 *
 * Logs the message number, state, severity and text, followed by the
 * server name, procedure name and line number when they are supplied, and
 * (with ENABLE_MSSQL_DEBUG) the statement recorded in g_CurrentStatement.
 *
 * Always returns 0.
 */
static int mssql_msg_handler(PDBPROCESS dbproc, DBINT msgno, int msgstate, int severity, LPCSTR msgtext, LPCSTR srvname, LPCSTR procname, DBUSMALLINT line)
{
    ErrorMessage("database: SQL Server message %ld, state %d, severity %d: \n\t%s\n",
                 msgno, msgstate, severity, msgtext);

    /* Optional details: each is skipped when absent or empty */
    if(srvname != NULL && *srvname != '\0')
    {
        ErrorMessage("Server '%s', ", srvname);
    }
    if(procname != NULL && *procname != '\0')
    {
        ErrorMessage("Procedure '%s', ", procname);
    }
    if(line)
    {
        ErrorMessage("Line %d", line);
    }
    ErrorMessage("\n");

#ifdef ENABLE_MSSQL_DEBUG
    if(strlen(g_CurrentStatement) != 0)
    {
        ErrorMessage("database:  The above error was caused by the following statement:\n%s\n", g_CurrentStatement);
    }
#endif

    return 0;
}
#endif
