#'@rdname CST_LogisticReg
#'@title Downscaling using interpolation and logistic regression.
#' 
#'@author J. Ramon, \email{jaumeramong@gmail.com}
#'@author E. Duzenli, \email{eren.duzenli@bsc.es}
#'
#'@description This function performs a downscaling using an interpolation and a logistic 
#'regression. See \code{\link[nnet]{multinom}} for further details. It is recommended that 
#'the observations are passed already in the target grid. Otherwise, the function will also 
#'perform an interpolation of the observed field into the target grid. The coarse scale and 
#'observation data can be either global or regional. In the latter case, the region is 
#'defined by the user. In principle, the coarse and observation data are intended to be of 
#'the same variable, although different variables can also be admitted. 
#'
#'@param exp an 's2dv object' with named dimensions containing the experimental field on
#'the coarse scale for which the downscaling is aimed. The object must have, at least,
#'the dimensions latitude, longitude, start date and member. The object is expected to be 
#'already subset for the desired region. Data can be in one or two integrated regions, e.g.,
#'crossing the Greenwich meridian. To get the correct results in the latter case,
#'the borders of the region should be specified in the parameter 'region'. See parameter
#''region'.
#'@param obs an 's2dv object' with named dimensions containing the observational field. 
#'The object must have, at least, the dimensions latitude, longitude and start date. The 
#'object is expected to be already subset for the desired region. 
#'@param exp_cor an optional 's2dv object' with named dimensions containing the seasonal
#'forecast experiment data. If the forecast is provided, it will be downscaled using the
#'hindcast and observations; if not provided, the hindcast will be downscaled instead. The
#'default value is NULL.
#'@param target_grid a character vector indicating the target grid to be passed to CDO.
#'It must be a grid recognised by CDO or a NetCDF file.
#'@param int_method a character vector indicating the regridding method to be passed to CDORemap.
#'Accepted methods are "con", "bil", "bic", "nn", "con2". If "nn" method is to be used, CDO_1.9.8
#'or newer version is required. For method "con2", CDO_2.2.2 or older version is required.
#'@param log_reg_method a character vector indicating the logistic regression method to be used.
#'Accepted methods are "ens_mean", "ens_mean_sd", "sorted_members". "ens_mean" uses the ensemble
#'mean anomalies as predictors in the logistic regression, "ens_mean_sd" uses the ensemble
#'mean anomalies and the ensemble spread (computed as the standard deviation of all the members) 
#'as predictors in the logistic regression, and "sorted_members" considers all the members
#'ordered decreasingly as predictors in the logistic regression. Default method is "ens_mean".
#'@param probs_cat a numeric vector indicating the percentile thresholds separating the 
#'climatological distribution into different classes (categories). Default to c(1/3, 2/3). See 
#'\code{\link[easyVerification]{convert2prob}}.
#'@param return_most_likely_cat if TRUE, the function returns the most likely category. If 
#'FALSE, the function returns the probabilities for each category. Default to FALSE.
#'@param points a list of two elements containing the point latitudes and longitudes
#'of the locations to downscale the model data. The list must contain the two elements
#'named as indicated in the parameters 'lat_dim' and 'lon_dim'. If the downscaling is
#'to a point location, only regular grids are allowed for exp and obs. Only needed if the
#'downscaling is to a point location.
#'@param method_point_interp a character vector indicating the interpolation method to interpolate 
#'model gridded data into the point locations. Accepted methods are "nearest", "bilinear", "9point", 
#'"invdist4nn", "NE", "NW", "SE", "SW". Only needed if the downscaling is to a point location.
#'@param lat_dim a character vector indicating the latitude dimension name in the element 'data' 
#'in exp and obs. Default set to "lat".
#'@param lon_dim a character vector indicating the longitude dimension name in the element 'data' 
#'in exp and obs. Default set to "lon".
#'@param sdate_dim a character vector indicating the start date dimension name in the element 
#''data' in exp and obs. Default set to "sdate".
#'@param member_dim a character vector indicating the member dimension name in the element
#''data' in exp and obs. Default set to "member".
#'@param time_dim a character vector indicating the time dimension name in the element
#''data' in exp and obs. Default set to "time".  
#'@param region a numeric vector indicating the borders of the downscaling region.
#'It consists of four elements in this order: lonmin, lonmax, latmin, latmax. lonmin refers
#'to the left border, while lonmax refers to the right border. latmin indicates the lower
#'border, whereas latmax indicates the upper border. If set to NULL (default), the function
#'takes the first and last elements of the latitudes and longitudes in obs.
#'@param loocv a logical value indicating whether to perform leave-one-out cross-validation
#'in the fitting of the logistic regression. In this procedure, all values from the 
#'corresponding year are excluded, so that when fitting the model for a given year, none 
#'of that year's data is used. If 'exp_cor' is provided, 'loocv' is set to FALSE (with a 
#'warning) so that the model is trained with the whole hindcast. Default to TRUE.
#'@param ncores an integer indicating the number of cores to use in parallel computation. 
#'The default value is NULL.
#'@import multiApply
#'@import nnet
#'@import plyr
#'@import abind
#'@import s2dv
#'@importFrom ClimProjDiags Subset
#'@importFrom stats sd
#'
#'@seealso \code{\link[nnet]{multinom}}
#'
#'@return A list with two s2dv_cube objects, 'exp' and 'obs'. The element 'data' of 'exp' 
#'contains the downscaled data, either in the form of probabilities for each category or 
#'the most likely category, while the element 'data' of 'obs' contains the observations 
#'used as reference in the downscaling. In both objects, 'coords' contains the coordinate 
#'information, 'dims' describes the dimension structure, and 'attrs' contains the 
#'associated attributes.
#'
#'@examples
#'\donttest{
#'exp <- rnorm(1500) 
#'dim(exp) <- c(member = 5, lat = 4, lon = 5, sdate = 15) 
#'exp_lons <- 1:5 
#'exp_lats <- 1:4 
#'obs <- rnorm(2700) 
#'dim(obs) <- c(lat = 12, lon = 15, sdate = 15) 
#'obs_lons <- seq(1,5, 4/14) 
#'obs_lats <- seq(1,4, 3/11) 
#'exp <- CSTools::s2dv_cube(data = exp, coords = list(lat = exp_lats, lon = exp_lons))
#'obs <- CSTools::s2dv_cube(data = obs, coords = list(lat = obs_lats, lon = obs_lons))
#'if (Sys.which("cdo") != "") {
#'res <- CST_LogisticReg(exp = exp, obs = obs, int_method = 'bil', target_grid = 'r1280x640', 
#'                       probs_cat = c(1/3, 2/3))
#'}
#'}
#'@export
CST_LogisticReg <- function(exp, obs, exp_cor = NULL, target_grid, int_method = NULL, 
                            log_reg_method = "ens_mean", probs_cat = c(1/3,2/3), 
                            return_most_likely_cat = FALSE, points = NULL, 
                            method_point_interp = NULL, lat_dim = "lat", lon_dim = "lon", 
                            sdate_dim = "sdate", member_dim = "member", time_dim = "time", 
                            region = NULL, loocv = TRUE, ncores = NULL) {

  # The elements 'data', 'coords' and 'attrs' of the inputs are read below, so the
  # inputs must be 's2dv_cube' objects
  if (!inherits(exp, 's2dv_cube')) {
    stop("Parameter 'exp' must be of the class 's2dv_cube'")
  }

  if (!inherits(obs, 's2dv_cube')) {
    stop("Parameter 'obs' must be of the class 's2dv_cube'")
  }

  # 'exp_cor' is optional, but when provided it is accessed through 'exp_cor$data'
  # and 'exp_cor$coords'. Fail early with a clear message instead of an obscure
  # '$ operator' error when a plain array is passed.
  if (!is.null(exp_cor) && !inherits(exp_cor, 's2dv_cube')) {
    stop("Parameter 'exp_cor' must be of the class 's2dv_cube'")
  }

  # Note: 'exp_cor$data' evaluates to NULL when 'exp_cor' is NULL
  res <- LogisticReg(exp = exp$data, obs = obs$data, exp_cor = exp_cor$data, 
                     exp_lats = exp$coords[[lat_dim]], exp_lons = exp$coords[[lon_dim]], 
                     obs_lats = obs$coords[[lat_dim]], obs_lons = obs$coords[[lon_dim]], 
                     target_grid = target_grid, probs_cat = probs_cat, 
                     return_most_likely_cat = return_most_likely_cat,
                     int_method = int_method, log_reg_method = log_reg_method, points = points, 
                     method_point_interp = method_point_interp, lat_dim = lat_dim, 
                     lon_dim = lon_dim, sdate_dim = sdate_dim, member_dim = member_dim, 
                     time_dim = time_dim, source_file_exp = exp$attrs$source_files[1], 
                     source_file_obs = obs$attrs$source_files[1],
                     region = region, loocv = loocv, ncores = ncores)

  # Modify data, lat and lon in the original s2dv_cube, adding the downscaled data
  obs$data <- res$obs
  obs$dims <- dim(obs$data)
  obs$coords[[lon_dim]] <- res$lon
  obs$coords[[lat_dim]] <- res$lat

  # The downscaled field is stored in the forecast cube when a forecast was provided,
  # and in the hindcast cube otherwise. Either way it is returned as element 'exp'.
  downscaled <- if (is.null(exp_cor)) exp else exp_cor
  downscaled$data <- res$data
  downscaled$dims <- dim(downscaled$data)
  downscaled$coords[[lon_dim]] <- res$lon
  downscaled$coords[[lat_dim]] <- res$lat

  res_s2dv <- list(exp = downscaled, obs = obs)

  return(res_s2dv)
}

#'@rdname LogisticReg
#'@title Downscaling using interpolation and logistic regression.
#' 
#'@author J. Ramon, \email{jaumeramong@gmail.com}
#'@author E. Duzenli, \email{eren.duzenli@bsc.es}
#'
#'@description This function performs a downscaling using an interpolation and a logistic 
#'regression. See \code{\link[nnet]{multinom}} for further details. It is recommended that 
#'the observations are passed already in the target grid. Otherwise, the function will also 
#'perform an interpolation of the observed field into the target grid. The coarse scale and 
#'observation data can be either global or regional. In the latter case, the region is 
#'defined by the user. In principle, the coarse and observation data are intended to be of 
#'the same variable, although different variables can also be admitted. 
#'
#'@param exp an array with named dimensions containing the experimental field on the
#'coarse scale for which the downscaling is aimed. The object must have, at least,
#'the dimensions latitude, longitude, start date and member. The object is expected to be 
#'already subset for the desired region. Data can be in one or two integrated regions, e.g.,
#'crossing the Greenwich meridian. To get the correct results in the latter case,
#'the borders of the region should be specified in the parameter 'region'. See parameter
#''region'.
#'@param obs an array with named dimensions containing the observational field. The object 
#'must have, at least, the dimensions latitude, longitude and start date. The object is 
#'expected to be already subset for the desired region. 
#'@param exp_cor an optional array with named dimensions containing the seasonal forecast
#'experiment data. If the forecast is provided, it will be downscaled using the hindcast and
#'observations; if not provided, the hindcast will be downscaled instead. The default value
#'is NULL.
#'@param exp_lats a numeric vector containing the latitude values in 'exp'. Latitudes must 
#'range from -90 to 90.
#'@param exp_lons a numeric vector containing the longitude values in 'exp'. Longitudes 
#'can range from -180 to 180 or from 0 to 360.
#'@param obs_lats a numeric vector containing the latitude values in 'obs'. Latitudes must
#'range from -90 to 90.
#'@param obs_lons a numeric vector containing the longitude values in 'obs'. Longitudes
#'can range from -180 to 180 or from 0 to 360.
#'@param target_grid a character vector indicating the target grid to be passed to CDO.
#'It must be a grid recognised by CDO or a NetCDF file.
#'@param int_method a character vector indicating the regridding method to be passed to CDORemap.
#'Accepted methods are "con", "bil", "bic", "nn", "con2". If "nn" method is to be used, CDO_1.9.8
#'or newer version is required. For method "con2", CDO_2.2.2 or older version is required.
#'@param log_reg_method a character vector indicating the logistic regression method to be used.
#'Accepted methods are "ens_mean", "ens_mean_sd", "sorted_members". "ens_mean" uses the ensemble
#'mean anomalies as predictors in the logistic regression, "ens_mean_sd" uses the ensemble
#'mean anomalies and the ensemble spread (computed as the standard deviation of all the members) 
#'as predictors in the logistic regression, and "sorted_members" considers all the members
#'ordered decreasingly as predictors in the logistic regression. Default method is "ens_mean".
#'@param probs_cat a numeric vector indicating the percentile thresholds separating the 
#'climatological distribution into different classes (categories). Default to c(1/3, 2/3). See 
#'\code{\link[easyVerification]{convert2prob}}.
#'@param return_most_likely_cat if TRUE, the function returns the most likely category. If 
#'FALSE, the function returns the probabilities for each category. Default to FALSE.
#'@param points a list of two elements containing the point latitudes and longitudes
#'of the locations to downscale the model data. The list must contain the two elements
#'named as indicated in the parameters 'lat_dim' and 'lon_dim'. If the downscaling is
#'to a point location, only regular grids are allowed for exp and obs. Only needed if the
#'downscaling is to a point location.
#'@param method_point_interp a character vector indicating the interpolation method to interpolate 
#'model gridded data into the point locations. Accepted methods are "nearest", "bilinear", "9point", 
#'"invdist4nn", "NE", "NW", "SE", "SW". Only needed if the downscaling is to a point location.
#'@param lat_dim a character vector indicating the latitude dimension name in the element 'data' 
#'in exp and obs. Default set to "lat".
#'@param lon_dim a character vector indicating the longitude dimension name in the element 'data' 
#'in exp and obs. Default set to "lon".
#'@param sdate_dim a character vector indicating the start date dimension name in the element 
#''data' in exp and obs. Default set to "sdate".
#'@param member_dim a character vector indicating the member dimension name in the element
#''data' in exp and obs. Default set to "member".
#'@param time_dim a character vector indicating the time dimension name in the element
#''data' in exp and obs. Default set to "time".  
#'@param source_file_exp a character vector with a path to an example file of the exp data.
#'Only needed if the downscaling is to a point location.
#'@param source_file_obs a character vector with a path to an example file of the obs data.
#'Only needed if the downscaling is to a point location.
#'@param region a numeric vector indicating the borders of the downscaling region.
#'It consists of four elements in this order: lonmin, lonmax, latmin, latmax. lonmin refers
#'to the left border, while lonmax refers to the right border. latmin indicates the lower
#'border, whereas latmax indicates the upper border. If set to NULL (default), the function
#'takes the first and last elements of the latitudes and longitudes in obs.
#'@param loocv a logical value indicating whether to perform leave-one-out cross-validation
#'in the fitting of the logistic regression. In this procedure, all values from the 
#'corresponding year are excluded, so that when fitting the model for a given year, none 
#'of that year's data is used. If 'exp_cor' is provided, 'loocv' is set to FALSE (with a 
#'warning) so that the model is trained with the whole hindcast. Default to TRUE.
#'@param ncores an integer indicating the number of cores to use in parallel computation. 
#'The default value is NULL.
#'@import multiApply
#'@import nnet
#'@import plyr
#'@import abind
#'@import s2dv
#'@importFrom ClimProjDiags Subset
#'@importFrom easyVerification convert2prob
#'
#'@seealso \code{\link[nnet]{multinom}}
#'
#'@return A list of four elements. 'data' contains the downscaled data, that could be either
#'in the form of probabilities for each category or the most likely category. 'obs' contains
#'the observations used as reference in the downscaling. 'lat' contains the downscaled 
#'latitudes, and 'lon' the downscaled longitudes.  
#'
#'@examples
#'\donttest{
#'exp <- rnorm(1500) 
#'dim(exp) <- c(member = 5, lat = 4, lon = 5, sdate = 15) 
#'exp_lons <- 1:5 
#'exp_lats <- 1:4 
#'obs <- rnorm(2700) 
#'dim(obs) <- c(lat = 12, lon = 15, sdate = 15) 
#'obs_lons <- seq(1,5, 4/14) 
#'obs_lats <- seq(1,4, 3/11) 
#'if (Sys.which("cdo") != "") {
#'res <- LogisticReg(exp = exp, obs = obs, exp_lats = exp_lats, exp_lons = exp_lons, 
#'                   obs_lats = obs_lats, obs_lons = obs_lons, int_method = 'bil', 
#'                   target_grid = 'r1280x640', probs_cat = c(1/3, 2/3))
#'}
#'}
#'@export
LogisticReg <- function(exp, obs, exp_cor = NULL, exp_lats, exp_lons, obs_lats, obs_lons, 
                        target_grid, int_method = NULL, log_reg_method = "ens_mean", 
                        probs_cat = c(1/3,2/3), return_most_likely_cat = FALSE, points = NULL, 
                        method_point_interp = NULL, lat_dim = "lat", lon_dim = "lon", 
                        sdate_dim = "sdate", member_dim = "member", time_dim = "time", 
                        source_file_exp = NULL, source_file_obs = NULL, region = NULL, 
                        loocv = TRUE, ncores = NULL) {

  #-----------------------------------
  # Checkings
  #-----------------------------------
  # Scalar conditions use '&&' / '||' so that later operands are only evaluated when
  # the earlier ones do not already decide the outcome.

  if (!is.null(int_method) && !inherits(int_method, 'character')) {
    stop("Parameter 'int_method' must be of the class 'character'")
  }

  if (!is.null(method_point_interp) && !inherits(method_point_interp, 'character')) {
    stop("Parameter 'method_point_interp' must be of the class 'character'")
  }

  if (!inherits(lat_dim, 'character')) {
    stop("Parameter 'lat_dim' must be of the class 'character'")
  }

  if (!inherits(lon_dim, 'character')) {
    stop("Parameter 'lon_dim' must be of the class 'character'")
  }

  if (!inherits(sdate_dim, 'character')) {
    stop("Parameter 'sdate_dim' must be of the class 'character'")
  }

  if (!inherits(member_dim, 'character')) {
    stop("Parameter 'member_dim' must be of the class 'character'")
  }

  if (!is.null(source_file_exp) && !inherits(source_file_exp, 'character')) {
    stop("Parameter 'source_file_exp' must be of the class 'character'")
  }

  if (!is.null(source_file_obs) && !inherits(source_file_obs, 'character')) {
    stop("Parameter 'source_file_obs' must be of the class 'character'")
  }

  if (!inherits(loocv, 'logical')) {
    stop("Parameter 'loocv' must be set to TRUE or FALSE")
  }

  if (is.na(match(lon_dim, names(dim(exp))))) {
    stop("Missing longitude dimension in 'exp', or does not match the parameter ",
         "'lon_dim'")
  }

  if (is.na(match(lat_dim, names(dim(exp))))) {
    stop("Missing latitude dimension in 'exp', or does not match the parameter ",
         "'lat_dim'")
  }

  if (is.na(match(sdate_dim, names(dim(exp)))) || is.na(match(sdate_dim, names(dim(obs))))) {
    stop("Missing start date dimension in 'exp' and/or 'obs', or does not match the parameter ",
         "'sdate_dim'")
  }

  # Only 'exp' is required to have the member dimension ('obs' may lack it)
  if (is.na(match(member_dim, names(dim(exp))))) {
    stop("Missing member dimension in 'exp' and/or 'obs', or does not match the parameter ",
         "'member_dim'")
  }

  if (!is.null(exp_cor)) {

    if (is.na(match(sdate_dim, names(dim(exp_cor))))) {
      stop("Missing start date dimension in 'exp_cor', or does not match the parameter ",
           "'sdate_dim'")
    }

    if (is.na(match(member_dim, names(dim(exp_cor))))) {
      stop("Missing member dimension in 'exp_cor', or does not match the parameter 'member_dim'")
    }

    if (is.na(match(lon_dim, names(dim(exp_cor))))) {
      stop("Missing longitude dimension in 'exp_cor', or does not match the parameter ",
           "'lon_dim'")
    }

    if (is.na(match(lat_dim, names(dim(exp_cor))))) {
      stop("Missing latitude dimension in 'exp_cor', or does not match the parameter ",
           "'lat_dim'")
    }

    if (loocv) { # loocv equal to false to train with the whole hindcast and predict with the forecast
      loocv <- FALSE
      warning("Forecast data will be downscaled. 'loocv' is set to FALSE ",
              "to train the model with the whole hindcast and predict with the forecast.")
    }
  }

  # When observations are pointwise
  if (!is.null(points) && !is.na(match("location", names(dim(obs))))) {
    point_obs <- TRUE
    # dimension aux in obs is needed 
    if (is.na(match("aux", names(dim(obs))))) {
      obs <- InsertDim(obs, posdim = 1, lendim = 1, name = "aux")
    }
  } else {
    point_obs <- FALSE
  }

  if (!is.null(points) && is.null(source_file_exp)) {
    stop("No source file found. Source file must be provided in the parameter 'source_file_exp'.")
  }

  if (!is.null(points) && is.null(method_point_interp)) {
    stop("Please provide the interpolation method to interpolate gridded data to point locations ",
         "through the parameter 'method_point_interp'.")
  }

  if (is.null(region)) {
    warning("The borders of the downscaling region have not been provided. Assuming ",
            "the four borders of the downscaling region are defined by the first and ",
            "last elements of the parametres 'obs_lats' and 'obs_lons'.")
    region <- c(obs_lons[1], obs_lons[length(obs_lons)], obs_lats[1], obs_lats[length(obs_lats)])
  }
  ## ncores
  if (!is.null(ncores)) {
    # '||' is required here: with '|' the arithmetic on a non-numeric 'ncores' would
    # fail before the intended error message could be raised
    if (!is.numeric(ncores) || any(ncores %% 1 != 0) || any(ncores < 0) ||
        length(ncores) > 1) {
      stop("Parameter 'ncores' must be a positive integer.")
    }
  }
  
  # the code is not yet prepared to handle members in the observations
  restore_ens <- FALSE
  if (member_dim %in% names(dim(obs))) {
    if (identical(as.numeric(dim(obs)[member_dim]), 1)) {
      restore_ens <- TRUE
      obs <- ClimProjDiags::Subset(x = obs, along = member_dim, indices = 1, drop = 'selected')
    } else {
      stop("Not implemented for observations with members ('obs' can have 'member_dim', ",
           "but it should be of length = 1).")
    }
  }

  #-----------------------------------
  # Interpolation
  #-----------------------------------
  if (.check_coords(lat1 = exp_lats, lat2 = obs_lats,
                    lon1 = exp_lons, lon2 = obs_lons)) {
    # 'exp' and 'obs' are used as they are, no interpolation is performed
    exp_interpolated <- NULL
    exp_interpolated$data <- exp
    exp_interpolated$lat <- exp_lats
    exp_interpolated$lon <- exp_lons
    exp_cor_interpolated <- NULL
    if (!is.null(exp_cor)) {
      exp_cor_interpolated <- NULL
      exp_cor_interpolated$data <- exp_cor
    }
    obs_ref <- obs
  } else {
    exp_interpolated <- Interpolation(exp = exp, lats = exp_lats, lons = exp_lons, 
                                      target_grid = target_grid, method_remap = int_method, 
                                      points = points, source_file = source_file_exp,
                                      lat_dim = lat_dim, lon_dim = lon_dim, 
                                      method_point_interp = method_point_interp,
                                      region = region, ncores = ncores)

    if (!is.null(exp_cor)) {
      exp_cor_interpolated <- Interpolation(exp = exp_cor, lats = exp_lats, lons = exp_lons, 
                                            target_grid = target_grid, points = points, 
                                            method_point_interp = method_point_interp,
                                            source_file = source_file_exp, lat_dim = lat_dim, 
                                            lon_dim = lon_dim, method_remap = int_method, 
                                            region = region, ncores = ncores)
    }
    # If after interpolating 'exp' data the coordinates do not match, the obs data is interpolated to
    # the same grid to force the matching
    if ((!.check_coords(lat1 = as.numeric(exp_interpolated$lat), lat2 = obs_lats,
                        lon1 = as.numeric(exp_interpolated$lon), lon2 = obs_lons)) ||
        !(point_obs)) {
      obs_interpolated <- Interpolation(exp = obs, lats = obs_lats, lons = obs_lons,
                                        target_grid = target_grid, method_remap = int_method,
                                        points = points, source_file = source_file_obs,
                                        lat_dim = lat_dim, lon_dim = lon_dim,
                                        method_point_interp = method_point_interp,
                                        region = region, ncores = ncores)
      obs_ref <- obs_interpolated$data
    } else {
      obs_ref <- obs
    }
  }

  #-----------------------------------
  # Predictors
  #-----------------------------------
  # compute ensemble mean anomalies
  if (log_reg_method == "ens_mean") {
    predictor <- .get_ens_mean_anom(obj_ens = exp_interpolated$data, member_dim = member_dim, 
                                    sdate_dim = sdate_dim, ncores = ncores)

    target_dims_predictor <- sdate_dim
    if (!is.null(exp_cor)) {
      clim_hcst <- Apply(exp_interpolated$data, target_dims = c(member_dim, sdate_dim), 
                         mean, ncores = ncores, 
                         na.rm = TRUE)$output1 ## climatology of hindcast ens mean
      ens_mean_fcst <- Apply(exp_cor_interpolated$data, target_dims = member_dim, 
                             mean, na.rm = TRUE, ncores = ncores)$output1 ## ens mean of the forecast
      forecast <- Ano(ens_mean_fcst, clim_hcst, ncores = ncores) 
      target_dims_forecast <- sdate_dim
    }
  } else if (log_reg_method == "ens_mean_sd") {

    ens_mean_anom <- .get_ens_mean_anom(obj_ens = exp_interpolated$data, member_dim = member_dim, 
                                        sdate_dim = sdate_dim, ncores = ncores)
    ens_sd <- .get_ens_sd(obj_ens = exp_interpolated$data, member_dim = member_dim, ncores = ncores)

    # merge two arrays into one array of predictors; the fractional 'along' makes abind
    # create a new leading dimension, which is named "pred" right after
    predictor <- abind(ens_mean_anom, ens_sd, along = 1/2)
    names(dim(predictor)) <- c("pred", names(dim(ens_mean_anom)))
     
    target_dims_predictor <- c(sdate_dim, "pred")

    if (!is.null(exp_cor)) {
      clim_hcst <- Apply(exp_interpolated$data, target_dims = c(member_dim, sdate_dim),
                         mean, ncores = ncores,
                         na.rm = TRUE)$output1 ## climatology of hindcast ens mean
      ens_mean_fcst <- Apply(exp_cor_interpolated$data, target_dims = member_dim,
                             mean, na.rm = TRUE, ncores = ncores)$output1 ## ens mean of the forecast
      forecast_mean_anom <- Ano(ens_mean_fcst, clim_hcst, ncores = ncores)
      forecast_sd <- .get_ens_sd(obj_ens = exp_cor_interpolated$data, 
                                 member_dim = member_dim, ncores = ncores)
      forecast <- abind(forecast_mean_anom, forecast_sd, along = 1/2)
      names(dim(forecast)) <- c("pred", names(dim(forecast_mean_anom)))

      target_dims_forecast <- c(sdate_dim, "pred")
    } 

  } else if (log_reg_method == "sorted_members") {
 
    if (!is.null(exp_cor)) {
      stop('sorted_members method cannot be used to downscale forecasts since ',
           'the ensemble members are generally exchangeable')
    }   
    predictor <- .sort_members(obj_ens = exp_interpolated$data, 
                               member_dim = member_dim, ncores = ncores)

    target_dims_predictor <- c(sdate_dim, member_dim)
  } else {
    stop(paste0(log_reg_method, " not recognised or not implemented."))
  } 

  k_out <- 1 ##  in case loocv = TRUE and the data is NOT daily (i.e., time_dim does not exist), 
             ## leave one data out.
  daily <- FALSE # time_dim does not exist
  if (time_dim %in% names(dim(obs_ref)) && dim(obs_ref)[time_dim] > 1) {
    daily <- TRUE # time_dim exists
    sdate_num <- as.numeric(dim(obs_ref)[sdate_dim])
    k_out <- as.numeric(dim(obs_ref)[time_dim])
    # time and start date are merged into a single dimension for the regression and
    # split again once the regression has been applied
    obs_ref <- MergeDims(obs_ref, merge_dims = c(time_dim, sdate_dim), rename_dim = sdate_dim)
    predictor <- MergeDims(predictor, merge_dims = c(time_dim, sdate_dim), rename_dim = sdate_dim)
    if (!is.null(exp_cor)) {
      sdate_num_fr <- as.numeric(dim(forecast)[sdate_dim]) ## sdate_num of forecast
      forecast <- MergeDims(forecast, merge_dims = c(time_dim, sdate_dim), rename_dim = sdate_dim)
    }
  }
  
  # convert observations to categorical predictands
  obs_cat <- Apply(obs_ref, target_dims = sdate_dim, function(x) {
                     if (all(is.na(x))) {
                       rep(NA, length(x))
                     } else {
                       terc <- easyVerification::convert2prob(as.vector(x), prob = probs_cat)
                       # One category index per start date. Rows in which no single
                       # category is flagged (presumably missing observations - to be
                       # confirmed against convert2prob) are set to NA, so that the
                       # result is always an integer vector and never a ragged list.
                       as.integer(apply(terc, 1, function(r) {
                         cat_idx <- which(r == 1)
                         if (length(cat_idx) == 1) cat_idx else NA_integer_
                       }))
                     }
                   }, output_dims = sdate_dim, ncores = ncores)$output1
  
  target_dims_predictand <- sdate_dim

  #-----------------------------------
  # Apply the logistic regressions
  #-----------------------------------
  if (is.null(exp_cor)) {
    ## case hindcast only
    res <- Apply(list(predictor, obs_cat), 
                 target_dims = list(target_dims_predictor, target_dims_predictand), 
                 fun = function(x, y) .log_reg(x = x, y = y, loocv = loocv, probs_cat = probs_cat,
                                               sdate_dim = sdate_dim, k_out = k_out), 
                 output_dims = c(sdate_dim, "category"), ncores = ncores)$output1
  } else {
    ## case hindcast - forecast
    res <- Apply(list(predictor, obs_cat, forecast), 
                 target_dims = list(target_dims_predictor, target_dims_predictand, 
                                    target_dims_forecast), 
                 fun = function(x, y, f) .log_reg(x = x, y = y, f = f, loocv = loocv, 
                                                  probs_cat = probs_cat, sdate_dim = sdate_dim,
                                                  k_out = k_out),
                 output_dims = c(sdate_dim, "category"), ncores = ncores)$output1
  }

  if (daily) {
    # The result has as many start dates as the data that was predicted: the hindcast
    # when no forecast is given, the forecast otherwise
    sdate_num_res <- if (is.null(exp_cor)) sdate_num else sdate_num_fr
    res <- SplitDim(res, split_dim = sdate_dim, new_dim_name = time_dim,
                    indices = rep(1:k_out, sdate_num_res))
    obs_ref <- SplitDim(obs_ref, split_dim = sdate_dim, new_dim_name = time_dim,
                        indices = rep(1:k_out, sdate_num))
  }

  # Reduce the probabilities to the most likely category if requested. This is done
  # for the hindcast as well as for the forecast results.
  if (return_most_likely_cat) {
    res <- Apply(res, target_dims = c(sdate_dim, "category"), .most_likely_category,
                 output_dims = sdate_dim, ncores = ncores)$output1
  }

  # restore ensemble dimension in observations if it existed originally
  if (restore_ens) {
    obs_ref <- s2dv::InsertDim(obs_ref, posdim = 1, lendim = 1, name = member_dim)
  }

  res <- list(data = res, obs = obs_ref, lon = exp_interpolated$lon, lat = exp_interpolated$lat)
  
  return(res)
}

.most_likely_category <- function(data) {
# data, expected dims: start date, category (in this order)
# Returns one value per start date (row): the index of the category with the highest
# value in that row, or NA when all the values of the row are missing.

  if (all(is.na(data))) {
    mlc <- rep(NA, nrow(data))
  } else {
    # which.max() returns integer(0) for a row with only NAs, which would make apply()
    # return a ragged list instead of a vector. Such rows are mapped to NA explicitly.
    mlc <- apply(data, 1, function(probs) {
      if (all(is.na(probs))) NA_integer_ else which.max(probs)
    })
  }
  return(mlc)
}

.sort_members <- function(obj_ens, member_dim, ncores = NULL) {
  # Sort the values along 'member_dim' in decreasing order, placing NAs last

  Apply(obj_ens, target_dims = member_dim, fun = sort,
        decreasing = TRUE, na.last = TRUE, ncores = ncores)$output1
}

.get_ens_sd <- function(obj_ens, member_dim, ncores = NULL) {
  # Ensemble spread: standard deviation along 'member_dim', ignoring NAs

  spread <- Apply(obj_ens, target_dims = member_dim, fun = sd,
                  na.rm = TRUE, ncores = ncores)
  spread$output1
}

.get_ens_mean_anom <- function(obj_ens, member_dim, sdate_dim, ncores = NULL) {
  # Ensemble mean anomalies: the ensemble mean (mean along 'member_dim') referred to
  # the climatology (mean along both 'member_dim' and 'sdate_dim'). NAs are ignored
  # in both means.

  # ensemble mean
  member_avg <- Apply(obj_ens, target_dims = member_dim, fun = mean,
                      na.rm = TRUE, ncores = ncores)$output1

  # climatology
  climatology <- Apply(obj_ens, target_dims = c(member_dim, sdate_dim), fun = mean,
                       na.rm = TRUE, ncores = ncores)$output1

  # anomalies of the ensemble mean with respect to the climatology
  Ano(member_avg, climatology, ncores = ncores)
}

# atomic functions for logistic regressions
.log_reg <- function(x, y, f = NULL, loocv, probs_cat, sdate_dim, k_out = 1) {
  # Train and apply the logistic regression for one series.
  # x: predictor(s) of the training period
  # y: categorical predictand of the training period
  # f: optional predictor(s) for which the prediction is requested
  # Returns a matrix with one column per category (length(probs_cat) + 1) and one row
  # per predicted step: the rows of 'f' when it is provided, the rows of the training
  # data otherwise. The matrix is filled with NA when the model cannot be fitted.

  tmp_df <- data.frame(x = x, y = y)

  # number of columns holding at least one non-missing value
  n_valid_cols <- sum(apply(tmp_df, 2, function(col) !all(is.na(col))))

  # if the data is all NA (everything, all columns but one, or the whole predictand),
  # force return NA
  if (all(is.na(tmp_df)) || n_valid_cols == 1 || all(is.na(tmp_df$y))) {
    # NROW() gives the length of a dimensionless 'f' and its first dimension otherwise
    n1 <- if (is.null(f)) nrow(tmp_df) else NROW(f)
    n2 <- length(probs_cat) + 1
    res <- matrix(NA, nrow = n1, ncol = n2)

  } else {
    # training
    lm1 <- .train_lr(df = tmp_df, loocv = loocv, k_out = k_out)

    # prediction
    res <- .pred_lr(lm1 = lm1, df = tmp_df, f = f, loocv = loocv, 
                    probs_cat = probs_cat, k_out = k_out)

    # if forecast is provided, and syear of forecast is 1, the prediction is reshaped
    # to a one-row matrix. NROW() is used instead of nrow() because nrow() returns NULL
    # for a dimensionless 'f', which would make the condition fail.
    if (!is.null(f) && NROW(f) == 1) {
      res <- array(res, dim = c(1, length(probs_cat) + 1))
    }

  }
  return(res)
}

#-----------------------------------
# Function to train the logistic regressions.
#-----------------------------------
.train_lr <- function(df, loocv, k_out = 1) {
  # df: data frame with the predictand in column 'y' and the predictors in the others
  # Returns a list of fitted models: a single one when 'loocv' is FALSE, one per
  # block of 'k_out' consecutive rows otherwise.

  # Keep only the columns holding at least one non-missing value
  has_values <- apply(df, 2, function(col) !all(is.na(col)))
  df <- df[ , has_values]

  if (!loocv) {
    # a single model trained with all the rows
    return(list(multinom(y ~ ., data = df, trace = FALSE)))
  }

  # one model per block, each of them trained without the rows of its own block
  # (i.e., omitting the data of the year including the corresponding day)
  n_blocks <- nrow(df) / k_out
  lm1 <- lapply(1:n_blocks, function(j) {
    first_row <- (j - 1) * k_out + 1
    left_out <- first_row:((j - 1) * k_out + k_out)
    multinom(y ~ ., data = df[-left_out, ], trace = FALSE)
  })

  return(lm1)
}

#-----------------------------------
# Function to apply the logistic regressions.
#-----------------------------------
.pred_lr <- function(df, lm1, f, loocv, probs_cat, k_out = 1) {
  # df: data frame used in the training (predictors and predictand 'y')
  # lm1: list of fitted models, as returned by .train_lr
  # f: optional predictor(s) to predict for; only used when 'loocv' is FALSE
  # Returns the predicted probabilities, with one column per category. With two
  # categories the prediction holds a single probability per row, which is expanded
  # into two columns (the first one being its complement).

  n_cat <- length(probs_cat) + 1

  # Expand the single probability of the two-category case into a two-column matrix
  two_cat_probs <- function(p_second, n_rows) {
    probs <- matrix(NA, nrow = n_rows, ncol = 2)
    probs[, 2] <- p_second
    probs[, 1] <- 1 - probs[, 2]
    colnames(probs) <- c(1, 2)
    probs
  }

  if (loocv) {
    # The error: "Error: Results must have the same dimensions." can 
    # appear when the number of sdates is insufficient

    # Each model predicts the block of 'k_out' rows that was left out in its training
    pred_vals_ls <- lapply(seq_along(lm1), function(j) {
      # test the daily data of the corresponding year
      window <- ((j - 1) * k_out + 1):((j - 1) * k_out + k_out)
      test_rows <- df[window, ]
      if (any(apply(test_rows, 2, function(x) all(is.na(x))))) {
        # no prediction is possible if any column is missing in the whole block
        if (n_cat == 2) {
          rep(NA, length(window))
        } else {
          array(NA, dim = c(length(window), n_cat))
        }
      } else {
        predict(lm1[[j]], test_rows, type = "probs")
      }
    })

    pred_vals <- do.call(rbind, pred_vals_ls)

    if (n_cat == 2) {
      # rbind() yields one row per block; transposing restores the original row order
      # when the values are read column-wise
      pred_vals <- two_cat_probs(t(pred_vals), nrow(df))
    }

  } else if (is.null(f)) {
    # a single model predicts the training data itself
    pred_vals <- predict(lm1[[1]], df, type = "probs")

    if (n_cat == 2) {
      pred_vals <- two_cat_probs(t(pred_vals), nrow(df))
    }

  } else {
    # a single model predicts the forecast; the column names of 'newdata' are built in
    # the same way as those of the training data frame
    if (is.null(dim(f))) {
      newdata <- data.frame(x = as.vector(f))
    } else {
      newdata <- data.frame(x = f)
    }
    pred_vals <- predict(lm1[[1]], newdata = newdata, type = "probs")

    if (n_cat == 2) {
      pred_vals <- two_cat_probs(pred_vals, NROW(f))
    }
  }

  return(pred_vals)
}
