#include <RcppArmadillo.h>
#include "integratemvn.h"
#include "tab2mat.h"
// [[Rcpp::depends(RcppArmadillo)]]

using namespace Rcpp;

//' Integrate over Random Effects
//'
//' Used to conduct Monte Carlo integration over Gaussian random effects.
//' Not intended to be called directly by most users.
//'
//' For each posterior draw (row of yhat), the random effect contributions of
//' every block are simulated, summed, added to that draw's fixed effects
//' predictions, back transformed, and averaged over the k Monte Carlo samples.
//'
//' @param d A list with model matrices for each random effect block.
//'   Must contain at least as many elements as sd.
//' @param sd A list with standard deviation matrices for each random effect block
//'   where rows are different posterior draws. May be an empty list, in which
//'   case there are no random effects to integrate over.
//' @param L A list with matrices for each random effect block containing the parts of
//'   the L matrix, the Cholesky decomposition of the random effect correlation matrix.
//'   Rows are different posterior draws. Must contain at least as many elements as sd.
//' @param k An integer, the number of samples for Monte Carlo integration.
//'   Must be at least 1.
//' @param yhat A matrix of the fixed effects predictions, with one row per
//'   posterior draw. Every matrix in sd and L must have at least as many rows
//'   as yhat; row i of each is used together with row i of yhat.
//' @param backtrans An integer, indicating the type of back transformation.
//'   0 indicates inverse logit (e.g., for logistic regression).
//'   1 indicates exponential (e.g., for poisson or negative binomial regression or if outcome was natural log transformed).
//'   2 indicates square (e.g., if outcome was square root transformed).
//'   3 indicates inverse (e.g., if outcome was inverse transformed such as Gamma regression)
//'   Any other integer results in no transformation. -9 is recommended as the option for no
//'   transformation as any future transformations supported will be other, positive integers.
//' @return A numeric matrix with the same dimensions as yhat containing the
//'   Monte Carlo integral calculated. If sd is an empty list, this is simply
//'   the back transformed yhat.
//' @export
//' @examples
//' integratere(
//'   d = list(matrix(1, 1, 1)),
//'   sd = list(matrix(1, 2, 1)),
//'   L = list(matrix(1, 2, 1)),
//'   k = 10L,
//'   yhat = matrix(0, 2, 1),
//'   backtrans = 0L)
// [[Rcpp::export]]
arma::mat integratere(List d, List sd, List L, int k, const arma::mat& yhat, int backtrans) {
  const arma::uword M = yhat.n_rows;
  const arma::uword N = yhat.n_cols;
  const int J = sd.length();

  // Reading d[re] / L[re] past the end of a List is undefined behaviour.
  if (d.length() < J || L.length() < J) {
    Rcpp::stop("d and L must each contain at least as many elements as sd");
  }
  if (k < 1) {
    Rcpp::stop("k must be a positive integer");
  }

  // Convert every block's inputs to Armadillo once. These conversions do not
  // depend on the posterior draw, so redoing them inside the draw loop (as a
  // full copy of each design matrix per draw) is wasted work.
  arma::field<arma::mat> dmats(J);
  arma::field<arma::mat> Lmats(J);
  arma::field<arma::mat> sdmats(J);
  for (int re = 0; re < J; re++) {
    dmats(re) = Rcpp::as<arma::mat>(d[re]);
    Lmats(re) = Rcpp::as<arma::mat>(L[re]);
    sdmats(re) = Rcpp::as<arma::mat>(sd[re]);
    // Row i of these matrices is paired with row i of yhat below.
    if (Lmats(re).n_rows < M || sdmats(re).n_rows < M) {
      Rcpp::stop("each sd and L matrix must have at least as many rows as yhat");
    }
  }

  arma::mat yhat2 = arma::zeros(M, N);

  for (arma::uword i = 0; i < M; i++) {
    // Sum of the simulated random effect contributions across all blocks.
    // Blocks are visited in the same order for every draw so the sequence of
    // integratemvn calls (and hence any random number stream) is unchanged.
    arma::mat Zall;
    for (int re = 0; re < J; re++) {
      const arma::mat cholmat = tab2mat(Lmats(re).row(i));
      const arma::rowvec sdvec = sdmats(re).row(i);
      arma::mat contrib = integratemvn(dmats(re), k, sdvec, cholmat);
      if (re == 0) {
        Zall = contrib;
      } else {
        Zall += contrib;
      }
    }
    // No random effect blocks: a single zero column leaves only yhat.
    if (J == 0) {
      Zall.zeros(N, 1);
    }

    // Add this draw's fixed effects predictions to every Monte Carlo sample.
    Zall.each_col() += yhat.row(i).t();

    switch (backtrans) {
      case 0: Zall = 1.0 / (1.0 + arma::exp(-Zall)); break;  // inverse logit
      case 1: Zall = arma::exp(Zall); break;                  // exponential
      case 2: Zall = arma::square(Zall); break;               // square
      case 3: Zall = 1.0 / Zall; break;                       // inverse
      default: break;                                         // no transformation
    }

    // Average over the Monte Carlo samples (columns of Zall).
    yhat2.row(i) = arma::mean(Zall, 1).t();
  }
  return yhat2;
}
