## Name of this package as a character string. NOTE(review): presumably a
## legacy convention from pre-namespace packages; confirm nothing relies on
## it before removing.
.packageName <- "exactLoglinTest"
## Build the simulation object consumed by update()/bab()/cab() and by the
## simtable.* functions.
##
## Fits the Poisson log-linear model `formula` to `data`, reorders the cells
## so that the last rank(x) rows of the design matrix form an invertible
## square block, and precomputes what the samplers need:
##   conde1, condv1  fitted means of the first n1 ("free") cells and their
##                   covariance conditional on the sufficient statistics
##                   under the normal approximation with variance diag(mu.hat)
##   x1, x2invt      design rows of the free cells and t(solve(x2)), used to
##                   recover the remaining cells from the free ones
##   s               observed sufficient statistics t(x) %*% y
##   dobs            observed value of `stat`
##   ord             cell permutation applied internally (order(ord) undoes it)
## Returns that list with class "bab" or "cab" according to `method`.
## Stops if the arguments are invalid or the model is saturated.
build.mcx.obj <- function(formula,
                          data,
                          stat = gof,
                          dens = hyper,
                          nosim = 10 ^ 3,
                          method = "bab",
                          savechain = FALSE,
                          tdf = 3,
                          maxiter = nosim,
                          p = NULL,
                          batchsize = NULL){
  ## fit the model, keeping the design matrix and the response
  glm.fit <- glm(formula,
                 family = poisson,
                 x = TRUE,
                 y = TRUE,
                 data = data)

  mu.hat <- fitted(glm.fit)
  y <- glm.fit$y
  x <- glm.fit$x
  r <- glm.fit$qr$rank

  ## drop aliased (linearly dependent) columns of the design matrix
  if (r < ncol(x))
    x <- x[, glm.fit$qr$pivot[seq_len(r)], drop = FALSE]

  errorcheck(y, x, stat, dens, nosim, method, savechain, tdf, maxiter, p, batchsize)

  n <- length(y)

  ## reorder the cells so that the last rank(x) rows of x are linearly
  ## independent; those cells are then determined by the first n1 cells
  ## together with the sufficient statistics
  qr.xt <- qr(t(x))
  n1 <- n - qr.xt$rank
  if (n1 < 1)
    stop("Model is saturated: no cells are free given the sufficient statistics")
  ord <- rev(qr.xt$pivot)
  x <- x[ord, , drop = FALSE]
  y <- y[ord]
  mu.hat <- mu.hat[ord]
  s <- t(x) %*% y
  free <- seq_len(n1)
  ## indices n1+1..n: the dependent cells of x, and (since n - n1 == ncol(x))
  ## also the sufficient-statistic block of v below
  rest <- (n1 + 1) : n
  ## drop = FALSE keeps matrices as matrices when n1 == 1
  x1 <- x[free, , drop = FALSE]
  ## only the inverse of x2 is needed
  x2invt <- t(solve(x[rest, , drop = FALSE]))
  dobs <- stat(y = y, mu = mu.hat, rowlabels = FALSE)

  ## conditional means and variances required by bab and cab:
  ## ctmp maps y to (y1, t(x) %*% y), so v is their covariance when
  ## var(y) = diag(mu.hat); condv1 is var(y1 | t(x) %*% y)
  mu.hat1 <- mu.hat[free]
  ctmp <- rbind(cbind(diag(nrow = n1), matrix(0, nrow = n1, ncol = n - n1)), t(x))
  v <- ctmp %*% diag(mu.hat, nrow = n) %*% t(ctmp)
  v12 <- v[free, rest, drop = FALSE]
  condv1 <- v[free, free, drop = FALSE] -
    v12 %*% solve(v[rest, rest, drop = FALSE]) %*% t(v12)

  ## the object required as the input to update
  args <- list(conde1 = mu.hat1,
               condv1 = condv1,
               dens = dens,
               dobs = dobs,
               mu.hat = mu.hat,
               n = n,
               n1 = n1,
               nosim = nosim,
               s = s,
               stat = stat,
               tdf = tdf,
               x = x,
               x1 = x1,
               x2invt = x2invt,
               y = y,
               ord = ord,
               glm.fit = glm.fit)
  if (method == "bab") {
    args$maxiter <- maxiter
    class(args) <- c("bab")
  } else if (method == "cab") {
    args$p <- p
    args$batchsize <- batchsize
    class(args) <- c("cab")
  }
  args
}






## Validate the arguments shared by build.mcx.obj() (and hence mcexact() and
## simulateConditional()). Stops with a message describing the first problem
## found; returns NULL invisibly when everything is acceptable.
##
##   y          vector of cell counts (build.mcx.obj passes the observed
##              table, which cab() also uses as the default starting value)
##   x          design matrix; must have full column rank
##   stat,dens  functions (test statistic, log target density)
##   nosim      number of simulations, > 0
##   method     "bab" or "cab"
##   savechain  not checked here
##   tdf        t degrees of freedom, > 0
##   maxiter    required for "bab"; when given must be >= nosim
##   p          required for "cab"; when given must lie in [0, 1]
##   batchsize  required for "cab"; > 0 and no larger than nosim
errorcheck <- function(y, x, stat, dens, nosim, method, savechain, tdf, maxiter, p, batchsize){
  ## is.numeric() covers both double and integer storage (is.real() is defunct)
  if (!is.numeric(y))
    stop("y must be real or integer valued")
  if (any(y < 0))
    stop("y must be non-negative")
  if (any(round(y) != y))
    stop("round(y) must equal y")

  if (!is.matrix(x))
    stop("x must be a matrix")
  if (!is.numeric(x))
    stop("x must be real or integer valued")
  if (ncol(x) < 2)
    stop("Model must have at least 2 parameters")
  if (nrow(x) < ncol(x))
    stop("Model has more parameters than datapoints")
  if (qr(x)$rank != ncol(x))
    stop("Design matrix must be of full rank")

  if (!is.function(stat)) stop("stat must be a function")

  if (!is.function(dens)) stop("dens must be a function")

  if (!is.numeric(nosim)) stop("nosim must be real or integer valued")
  if (nosim <= 0) stop("nosim must be > 0")

  if (length(method) != 1 || !(method %in% c("bab", "cab")))
    stop("method must be either cab or bab")

  if (!is.numeric(tdf)) stop("tdf must be real or integer valued")
  if (tdf <= 0) stop("tdf must be > 0")

  if (!is.null(maxiter) && !is.numeric(maxiter))
    stop("maxiter must be null, real or integer valued")
  if (method == "bab" && is.null(maxiter))
    stop("maxiter must be specified if method = bab")
  ## guard against NULL: NULL < nosim is logical(0) and would break if()
  if (!is.null(maxiter) && maxiter < nosim)
    stop("maxiter must be >= nosim")

  if (!is.null(p) && !is.numeric(p))
    stop("p must be null, real or integer valued")
  if (method == "cab" && is.null(p))
    stop("p must be specified if method = cab")
  if (!is.null(p) && (p < 0 || p > 1))
    stop("p must be between 0 and 1")

  if (!is.null(batchsize) && !is.numeric(batchsize))
    stop("batchsize must be null, real or integer valued")
  if (method == "cab") {
    if (is.null(batchsize))
      stop("batchsize must be specified if method = cab")
    if (batchsize <= 0) stop("batchsize must be > 0")
    if (nosim / batchsize < 1) stop("batches are too large")
  }
  invisible(NULL)
}
## Goodness-of-fit statistics comparing counts `y` with fitted means `mu`.
##
## With rowlabels = TRUE, returns the statistic names c("deviance", "Pearson").
## Otherwise returns c(deviance, Pearson chi-square), where the deviance is
## 2 * sum(y * log(y / mu)) over cells with y > 0 (0 * log(0) taken as 0).
## The -(y - mu) term of the Poisson deviance is omitted; it vanishes whenever
## sum(y) == sum(mu), e.g. for models containing an intercept.
gof <- function(y = NULL, mu = NULL, rowlabels = FALSE){
  if (rowlabels) return(c("deviance", "Pearson"))
  pos <- y != 0
  c(2 * sum(y[pos] * log(y[pos] / mu[pos])), sum((y - mu) ^ 2 / mu))
}
## Log of prod(1 / y_i!) for a table y. For a Poisson log-linear model this
## is the exact conditional distribution of the table given its sufficient
## statistics, up to an additive constant.
hyper <- function(y) {
  -sum(lfactorial(y))
}
## Monte Carlo estimate of the exact conditional p-value(s) of `stat` for the
## Poisson log-linear model `formula` fitted to `data`.
## Builds the simulation object with build.mcx.obj() (see there for the
## arguments) and runs the first set of simulations through update().
mcexact <- function(formula,
                    data,
                    stat = gof,
                    dens = hyper,
                    nosim = 10 ^ 3,
                    method = "bab",
                    savechain = FALSE,
                    tdf = 3,
                    maxiter = nosim,
                    p = NULL,
                    batchsize = NULL){
  sim.obj <- build.mcx.obj(formula = formula, data = data, stat = stat,
                           dens = dens, nosim = nosim, method = method,
                           savechain = savechain, tdf = tdf,
                           maxiter = maxiter, p = p, batchsize = batchsize)
  ## dispatches on the class ("bab" or "cab") set by build.mcx.obj()
  update(sim.obj, savechain = savechain)
}

## Package load hook: loads the compiled code that the .Call(...,
## PACKAGE = "exactLoglinTest") invocations rely on.
## NOTE(review): .First.lib is the pre-namespace hook; newer R versions
## presumably use .onLoad()/useDynLib instead — confirm against the NAMESPACE.
.First.lib <- function(lib, pkg){
  library.dynam(chname = "exactLoglinTest", package = pkg, lib.loc = lib)
}
## Print method for "bab" objects.
## Prints a table with rows observed.stat, pvalue and mcse, one column per
## statistic named by x$stat(rowlabels = TRUE), then returns `x` invisibly.
## The method must print itself: auto-printing calls print() and discards
## its return value, so the old version displayed nothing.
print.bab <- function(x, ...){
  rval <- as.data.frame(rbind(x$dobs, x$phat, x$mcse))
  rownames(rval) <- c("observed.stat", "pvalue", "mcse")
  colnames(rval) <- x$stat(rowlabels = TRUE)
  print(rval, ...)
  invisible(x)
}

## Print method for "cab" objects.
## Prints a table with rows observed.stat, pvalue and mcse, one column per
## statistic named by x$stat(rowlabels = TRUE), then returns `x` invisibly.
## The method must print itself: auto-printing calls print() and discards
## its return value, so the old version displayed nothing.
print.cab <- function(x, ...){
  rval <- as.data.frame(rbind(x$dobs, x$phat, x$mcse))
  rownames(rval) <- c("observed.stat", "pvalue", "mcse")
  colnames(rval) <- x$stat(rowlabels = TRUE)
  print(rval, ...)
  invisible(x)
}
## Log-probability that independent scaled t variables m + sqrt(s) * T,
## T ~ t(df), round to the integers y. The result is
## sum(log(P(y - 0.5 < m + sqrt(s) * T < y + 0.5))), vectorised over y, m, s.
## In the right tail the interval probability is computed as a difference of
## upper-tail probabilities. Differencing two values of pt() that are both
## close to 1 cancels catastrophically and can return log(0) = -Inf.
rounded.tprob <- function(y, m, s, df) {
  scale <- sqrt(s)
  hi <- (y + .5 - m) / scale
  lo <- (y - .5 - m) / scale
  prob <- ifelse(lo > 0,
                 pt(lo, df, lower.tail = FALSE) - pt(hi, df, lower.tail = FALSE),
                 pt(hi, df) - pt(lo, df))
  sum(log(prob))
}
## Draw tables from the importance-sampling proposal used by bab(), without
## computing test statistics.
##
##   args     object of class "bab" built by build.mcx.obj()
##   nosim    number of valid tables wanted (default: args$nosim)
##   maxiter  maximum number of proposals (default: args$maxiter)
## Returns a matrix with one row per valid (all non-negative) table. The
## first length(args$y) columns hold the cell counts in the original data
## order. The last column holds the log importance weight (log target
## density minus log proposal probability); the draws are only meaningful
## together with these weights. If no valid table is found, issues a warning
## and returns the warning message invisibly.
simtable.bab <-  function(args, nosim = NULL, maxiter = NULL){
  if (!is.null(nosim)) args$nosim <- nosim
  if (!is.null(maxiter)) args$maxiter <- maxiter
  if (args$nosim >  args$maxiter) {
    warning("update.bab, nosim > maxiter, setting maxiter = nosim")
    args$maxiter <- args$nosim
  }
  nocol <- length(args$y)
  chain <- matrix(0, args$nosim, nocol + 1)
  i <- 1
  j <- 0
  while ((i <= args$nosim) && (j < args$maxiter)){
    j <- j + 1
    shuffle <- sample(1 : args$n1)
    conde1.permute <-  args$conde1[shuffle]
    ## this is equal to P %*% condv1 %*% t(P); drop = FALSE keeps a 1x1
    ## matrix when n1 == 1 so diag() below still works
    condv1.permute <-  args$condv1[shuffle, shuffle, drop = FALSE]
    ## conde1.permute and condv1.permute are now the sequential means and
    ## variances; draw the free cells and undo the shuffle
    y1.new.permute <- .Call("multinormfull",
                            conde1.permute,
                            condv1.permute,
                            args$tdf,
                            PACKAGE="exactLoglinTest")
    y1.new <- y1.new.permute[order(shuffle)]
    ## the remaining cells are fixed by the sufficient statistics
    y2.new <- args$x2invt %*% (args$s - t(args$x1) %*% y1.new)
    ## y.new is computed in double precision; round to integer counts
    y.new <- round(c(y1.new, y2.new))
    if (all(y.new >= 0)){
      ## importance weight on the log scale
      w <- args$dens(y.new) - rounded.tprob(y1.new.permute,
                                            conde1.permute,
                                            diag(condv1.permute),
                                            args$tdf)
      chain[i, ] <- c(y.new, w)
      i <- i + 1
    }
  }
  if (i == 1)
    warning("No valid tables found")
  else
    ## restore the original cell order and keep the weight column
    chain[seq_len(i - 1), c(order(args$ord), nocol + 1), drop = FALSE]
}







## Run the Metropolis-Hastings sampler used by cab() and return the visited
## tables, without computing test statistics.
##
##   args     object of class "cab" built by build.mcx.obj()
##   nosim    number of iterations (default: args$nosim)
##   p        overrides args$p, the parameter of the Binomial(n1 - 1, p)
##            number of free cells held fixed at each step
##   y.start  starting table in the ORIGINAL cell order (default: the
##            observed table); must be non-negative and have the observed
##            sufficient statistics
## Returns an nosim x length(args$y) matrix of tables in the original order.
simtable.cab <- function(args, nosim = NULL, p = NULL, y.start = NULL){
  ## error checking and initializing
  if (!is.null(p)){
    if (!is.numeric(p)) stop("p must be real valued")
    else if ((p < 0) | (p > 1)) stop("p must be in [0,1]")
    args$p <- p
  }
  if (!is.null(nosim)) args$nosim <- nosim
  if (is.null(y.start)) {
    y.start <- args$y
  } else {
    if (length(y.start) != length(args$y) || any(y.start < 0))
      stop("invalid starting value")
    ## move the user's table into the internal cell order of args$x / args$y
    y.start <- y.start[args$ord]
    if (!isTRUE(all.equal(as.vector(t(args$x) %*% y.start), as.vector(args$s))))
      stop("invalid starting value")
  }

  nocol <- length(args$y)
  chain <- matrix(0, args$nosim, nocol)

  y.old <- y.start
  y1.old <- y.start[1 : args$n1]
  for (i in seq_len(args$nosim)){
    shuffle <- sample(1 : args$n1)
    conde1.permute <-  args$conde1[shuffle]
    condv1.permute <-  args$condv1[shuffle, shuffle, drop = FALSE]
    ## k is the number of (shuffled) free cells left fixed
    k <- rbinom(1, args$n1 - 1, args$p)
    ## separate y1 into those that stay the same and those that get updated
    y1.old.permute <- y1.old[shuffle]
    staysfixed <- if (k > 0) y1.old.permute[1 : k] else NULL
    getsupdated <- y1.old.permute[(k + 1) : args$n1]
    ## multinorm returns the proposal and the means needed for the reverse move
    temp <- .Call("multinorm",
                  conde1.permute,
                  condv1.permute,
                  as.double(staysfixed),
                  as.double(y1.old.permute),
                  args$tdf,
                  as.integer(k),
                  PACKAGE="exactLoglinTest")
    y1.new.permute <- temp[[1]]
    conde1.old.permute <- temp[[2]]
    changed <- y1.new.permute[(k + 1) : args$n1]
    tf2 <- tf3 <- FALSE
    tf1 <- all(changed >= 0)
    if (tf1){
      y1.new <- y1.new.permute[order(shuffle)]
      y2.new <- round(args$x2invt %*% (args$s - t(args$x1) %*% y1.new))
      tf2 <- all(y2.new >= 0)
      if (tf2){
        y.new <- c(y1.new, y2.new)
        target.new <- args$dens(y.new)
        target.old <- args$dens(y.old)
        mean.new <- conde1.permute[(k + 1) : args$n1]
        mean.old <- conde1.old.permute[(k + 1) : args$n1]
        var.old.new <- diag(condv1.permute)[(k + 1) : args$n1]
        cand.new <- rounded.tprob(changed    , mean.new, var.old.new, args$tdf)
        cand.old <- rounded.tprob(getsupdated, mean.old, var.old.new, args$tdf)
        ## Metropolis-Hastings log acceptance ratio
        w <- target.new +  cand.old - target.old - cand.new
        tf3 <- log(runif(1)) <=  w
      }
    }
    if (tf1 && tf2 && tf3){
      y.old <- y.new
      y1.old <- y1.new
    } else {
      ## rejected: the chain stays at y.old / y1.old
      y.new <- y.old
    }
    chain[i, ] <- y.new
  }
  chain[, order(args$ord), drop = FALSE]
}

## Simulate tables from (an approximation to) the conditional distribution of
## the table given the sufficient statistics of the Poisson log-linear model
## `formula` fitted to `data`.
##   method = "bab": importance-sampling draws from simtable.bab(), with
##                   the log importance weight as the last column
##   method = "cab": Markov chain draws from simtable.cab(), started at
##                   y.start (original cell order) or at the observed table
## The remaining arguments are as in build.mcx.obj().
simulateConditional <- function(formula,
                                data,
                                dens = hyper,
                                nosim = 10 ^ 3,
                                method = "bab",
                                tdf = 3,
                                maxiter = nosim,
                                p = NULL,
                                y.start = NULL){
  args <- build.mcx.obj(formula = formula,
                        data = data,
                        dens = dens,
                        nosim = nosim,
                        method = method,
                        tdf = tdf,
                        maxiter = maxiter,
                        p = p,
                        ## batchsize is not used when only tables are
                        ## simulated, but errorcheck() requires one with
                        ## nosim / batchsize >= 1 for method = "cab"
                        batchsize = nosim)

  if (method == "bab")
    simtable.bab(args)
  else if (method == "cab")
    simtable.cab(args, y.start = y.start)
}
## Summary method for "bab" objects.
## Writes the run settings (iterations so far, t df, number of cells, free
## cells, next nosim/maxiter, proportion of valid tables) to the console,
## then returns a data frame with rows observed.stat, pvalue and mcse and
## one column per statistic named by object$stat(rowlabels = TRUE).
summary.bab <- function(object, ...){
  labels <- c("Number of iterations       = ",
              "T degrees of freedom       = ",
              "Number of counts           = ",
              "df                         = ",
              "Next update has nosim      = ",
              "Next update has maxiter    = ",
              "Proportion of valid tables = ")
  values <- list(object$startiter - 1,
                 object$tdf,
                 object$n,
                 object$n1,
                 object$nosim,
                 object$maxiter,
                 object$perpos)
  for (k in seq_along(labels)) cat(labels[k], values[[k]], "\n")
  cat("\n")

  tab <- as.data.frame(rbind(object$dobs, object$phat, object$mcse))
  rownames(tab) <- c("observed.stat", "pvalue", "mcse")
  colnames(tab) <- object$stat(rowlabels = TRUE)
  tab
}

## Summary method for "cab" objects.
## Writes the run settings (iterations so far, t df, number of cells, free
## cells, batches, batch size, next nosim, proportion of valid proposals) to
## the console, then returns a data frame with rows observed.stat, pvalue
## and mcse and one column per statistic named by object$stat(rowlabels = TRUE).
summary.cab <- function(object,...){
  labels <- c("Number of iterations       = ",
              "T degrees of freedom       = ",
              "Number of counts           = ",
              "df                         = ",
              "Number of batches          = ",
              "Batchsize                  = ",
              "Next update has nosim      = ",
              "Proportion of valid tables = ")
  values <- list(object$startiter - 1,
                 object$tdf,
                 object$n,
                 object$n1,
                 object$nobatches,
                 object$batchsize,
                 object$nosim,
                 object$perpos)
  for (k in seq_along(labels)) cat(labels[k], values[[k]], "\n")
  cat("\n")

  tab <- as.data.frame(rbind(object$dobs, object$phat, object$mcse))
  rownames(tab) <- c("observed.stat", "pvalue", "mcse")
  colnames(tab) <- object$stat(rowlabels = TRUE)
  tab
}

## update() method for "bab" objects (Booth-Butler importance sampling).
## Runs further simulations on an object built by mcexact() or
## build.mcx.obj(); everything in ... (nosim, maxiter, savechain) is
## forwarded to bab(), which does the work.
update.bab <- function(object, ...){
  bab(args = object, ...)
}

## Booth-Butler importance-sampling estimate of the exact conditional
## p-value(s); the workhorse behind update.bab().
##
##   args       object of class "bab" from build.mcx.obj() or a previous call
##   nosim      number of valid tables to generate (default: args$nosim)
##   maxiter    maximum number of proposals this call (default: args$maxiter)
##   savechain  if TRUE, args$chain holds each valid table's statistics and
##              log importance weight
## Running sums are kept in `args`, so repeated calls refine the estimate.
## Returns `args` with phat (self-normalised IS estimate of P(stat >= dobs)),
## mcse (its standard error) and perpos (fraction of this call's proposals
## that gave a valid table).
bab <-  function(args, nosim = NULL, maxiter = NULL, savechain = FALSE){
  if (!is.null(nosim)) args$nosim <- nosim
  if (!is.null(maxiter)) args$maxiter <- maxiter
  if (args$nosim >  args$maxiter) {
    warning("update.bab, nosim > maxiter, setting maxiter = nosim")
    args$maxiter <- args$nosim
  }
  if (is.null(args$startiter)) args$startiter <- 1
  if (is.null(args$sumdw)) args$sumdw <- 0
  if (is.null(args$sumdwsq)) args$sumdwsq <- 0
  if (is.null(args$sumw)) args$sumw <- 0
  if (is.null(args$sumwsq)) args$sumwsq <- 0
  if (savechain) {
    stat.names <- args$stat(rowlabels = TRUE)
    args$chain <- matrix(0, args$nosim, length(stat.names) + 1)
    colnames(args$chain) <- c(stat.names, "log imp weight")
  } else {
    args$chain <- NULL
  }
  perpos <- 0

  i <- args$startiter
  j <- 0
  while ((i - args$startiter < args$nosim) && (j < args$maxiter)){
    j <- j + 1
    shuffle <- sample(1 : args$n1)
    conde1.permute <-  args$conde1[shuffle]
    ## this is equal to P %*% condv1 %*% t(P); drop = FALSE keeps a 1x1
    ## matrix when n1 == 1 so diag() below still works
    condv1.permute <-  args$condv1[shuffle, shuffle, drop = FALSE]
    ## conde1.permute and condv1.permute are now the sequential means and
    ## variances; draw the free cells and undo the shuffle
    y1.new.permute <- .Call("multinormfull",
                            conde1.permute,
                            condv1.permute,
                            args$tdf,
                            PACKAGE="exactLoglinTest")
    y1.new <- y1.new.permute[order(shuffle)]
    y2.new <- args$x2invt %*% (args$s - t(args$x1) %*% y1.new)
    ## y.new is computed in double precision; round to integer counts
    y.new <- round(c(y1.new, y2.new))
    if (all(y.new >= 0)){
      perpos <- perpos + 1
      d <- args$stat(y = y.new, mu = args$mu.hat, rowlabels = FALSE)
      ## importance weight on the log scale
      w <- args$dens(y.new) - rounded.tprob(y1.new.permute, conde1.permute,
                                            diag(condv1.permute), args$tdf)
      if (savechain) args$chain[i - args$startiter + 1, ] <- c(d, w)

      ## rescale every weight by the first valid table's weight (kept across
      ## calls) to avoid overflow; the estimates are invariant to this scale
      if (is.null(args$impconst)) args$impconst <- w
      w <- exp(w - args$impconst)
      ## partial sums required for the importance sampling estimate
      args$sumdw <- args$sumdw + (d >= args$dobs) * w
      args$sumdwsq <- args$sumdwsq + (d >= args$dobs) * w ^ 2
      args$sumw <- args$sumw  + w
      args$sumwsq <- args$sumwsq + w ^ 2
      i <- i + 1
    }
  }
  nvalid <- i - args$startiter
  if (nvalid == 0)
    warning("No valid tables found")
  else if (savechain)
    args$chain <- args$chain[seq_len(nvalid), , drop = FALSE]
  args$startiter <- i
  theta <- args$sumdw / args$sumw
  ## standard error of the self-normalised estimate:
  ## sqrt(sum(w^2 * (I - theta)^2)) / sum(w), expanded using I^2 == I.
  ## Dividing by sum(w), not by the number of tables, makes it invariant to
  ## the arbitrary weight scale set by impconst.
  setheta <- sqrt((1 - 2 * theta) * args$sumdwsq + (theta ^ 2) * args$sumwsq) / args$sumw
  args$phat <- theta
  args$mcse <- setheta
  ## proportion of proposals in this call that produced a valid table
  args$perpos <- if (j > 0) perpos / j else 0
  args
}

## update() method for "cab" objects (Metropolis-Hastings sampler).
## Runs further iterations on an object built by mcexact() or
## build.mcx.obj(); everything in ... (nosim, batchsize, savechain, p,
## flush) is forwarded to cab(), which does the work.
update.cab <- function(object,...){
  cab(args = object, ...)
}

## Markov chain (Metropolis-Hastings) estimate of the exact conditional
## p-value(s); the workhorse behind update.cab().
##
## Each iteration shuffles the free cells and holds the first
## k ~ Binomial(n1 - 1, p) of them fixed. New values for the rest are
## proposed by the C routine "multinorm"; their proposal probabilities are
## evaluated with rounded.tprob(). The dependent cells are recovered from the
## sufficient statistics. Proposals with a negative cell are rejected.
##
##   args       object of class "cab" from build.mcx.obj() or a previous call
##   nosim      number of iterations (default: args$nosim)
##   batchsize  batch length for the batch-means standard error
##   savechain  if TRUE, args$chain holds each iteration's statistics
##   p          overrides args$p
##   flush      if TRUE, discard the estimates accumulated by earlier calls
##              (the chain itself still continues from its current state)
## Returns `args` with phat (running p-value estimate), mcse (batch-means
## standard error), mhap (accepted moves), perpos (fraction of valid
## proposals in this call) and y.start / y1.start (current state, internal
## cell order) so that the next call resumes the chain.
cab <- function(args, nosim = NULL, batchsize = NULL, savechain = FALSE, p = NULL, flush = FALSE){
  ## error checking and initializing
  if (!is.null(p)){
    if (!is.numeric(p)) stop("p must be real valued")
    else if ((p < 0) | (p > 1)) stop("p must be in [0,1]")
    args$p <- p
  }
  if (!is.null(batchsize)) args$batchsize <- batchsize
  if (!is.null(nosim)) args$nosim <- nosim
  ## resume from the state left by a previous call, else the observed table
  y.start <- if (is.null(args$y.start)) args$y else args$y.start
  if (is.null(args$startiter) || flush) args$startiter <- 1
  phat <- if (is.null(args$phat) || flush) 0 else args$phat
  if (is.null(args$mhap) || flush) args$mhap <- 0
  bmsq <- if (is.null(args$bmsq) || flush) 0 else args$bmsq
  nobatches <- if (is.null(args$nobatches) || flush) 0 else args$nobatches
  current.batchmean <-
    if (is.null(args$current.batchmean) || flush) 0 else args$current.batchmean
  if (savechain){
    stat.names <- args$stat(rowlabels = TRUE)
    args$chain <- matrix(0, args$nosim, length(stat.names))
    colnames(args$chain) <- stat.names
  } else {
    args$chain <- NULL
  }

  perpos <- 0
  y.old <- y.start
  y1.old <- y.start[1 : args$n1]
  for (i in args$startiter : (args$startiter + args$nosim - 1)){
    shuffle <- sample(1 : args$n1)
    conde1.permute <-  args$conde1[shuffle]
    condv1.permute <-  args$condv1[shuffle, shuffle, drop = FALSE]
    ## k is the number of (shuffled) free cells left fixed
    k <- rbinom(1, args$n1 - 1, args$p)
    ## separate y1 into those that stay the same and those that get updated
    y1.old.permute <- y1.old[shuffle]
    staysfixed <- if (k > 0) y1.old.permute[1 : k] else NULL
    getsupdated <- y1.old.permute[(k + 1) : args$n1]
    ## multinorm returns the proposal and the means needed for the reverse move
    temp <- .Call("multinorm",
                  conde1.permute,
                  condv1.permute,
                  as.double(staysfixed),
                  as.double(y1.old.permute),
                  args$tdf,
                  as.integer(k),
                  PACKAGE="exactLoglinTest")
    y1.new.permute <- temp[[1]]
    conde1.old.permute <- temp[[2]]
    changed <- y1.new.permute[(k + 1) : args$n1]
    tf2 <- tf3 <- FALSE
    tf1 <- all(changed >= 0)
    if (tf1){
      y1.new <- y1.new.permute[order(shuffle)]
      y2.new <- round(args$x2invt %*% (args$s - t(args$x1) %*% y1.new))
      tf2 <- all(y2.new >= 0)
      if (tf2){
        perpos <- perpos + 1
        y.new <- c(y1.new, y2.new)
        target.new <- args$dens(y.new)
        target.old <- args$dens(y.old)
        mean.new <- conde1.permute[(k + 1) : args$n1]
        mean.old <- conde1.old.permute[(k + 1) : args$n1]
        var.old.new <- diag(condv1.permute)[(k + 1) : args$n1]
        cand.new <- rounded.tprob(changed    , mean.new, var.old.new, args$tdf)
        cand.old <- rounded.tprob(getsupdated, mean.old, var.old.new, args$tdf)
        ## Metropolis-Hastings log acceptance ratio
        w <- target.new +  cand.old - target.old - cand.new
        tf3 <- log(runif(1)) <=  w
      }
    }
    if (tf1 && tf2 && tf3){
      y.old <- y.new
      y1.old <- y1.new
      args$mhap <- args$mhap + 1
    } else {
      ## rejected: the chain stays at y.old / y1.old
      y.new <- y.old
    }
    d <- args$stat(y = y.new, mu = args$mu.hat, rowlabels = FALSE)
    if (savechain) args$chain[i - args$startiter + 1, ] <- d

    phat <- (phat * (i - 1) + (d >= args$dobs)) / i
    ## update the batch-mean estimate; batch.iter runs 1..batchsize so every
    ## batch, including the first, contains exactly batchsize iterations
    batch.iter <- (i - 1) %% args$batchsize + 1
    if (batch.iter == 1) current.batchmean <- 0
    current.batchmean <- (current.batchmean * (batch.iter - 1) + (d >= args$dobs)) / batch.iter
    if (batch.iter == args$batchsize) {
      bmsq <- bmsq + current.batchmean ^ 2
      nobatches <- nobatches + 1
    }
  }
  args$startiter <- i + 1
  ## keep the current batch mean in case the simulation is restarted
  args$current.batchmean <- current.batchmean
  args$bmsq <- bmsq
  args$nobatches <- nobatches
  args$phat <- phat
  args$mcse <- sqrt((bmsq /  nobatches - phat ^ 2) / nobatches)
  ## current state of the chain (internal cell order), read back above
  args$y.start <- y.old
  args$y1.start <- y1.old
  args$perpos <- perpos / args$nosim
  args
}

