#===============================================================================
# WGCNA Power Selection Utilities
# Author: Sisi Shao
# This file contains a helper function to select the soft-thresholding power
# for WGCNA network construction based on the scale-free topology criterion.
#===============================================================================

#' Select the best soft-thresholding power for WGCNA
#'
#' @description
#' Analyze scale-free topology fit across candidate powers and choose the
#' soft-thresholding power \eqn{\beta}.
#'
#' @details
#' \code{WGCNA::pickSoftThreshold()} is evaluated over \code{powers} (sorted
#' and de-duplicated first). The power is then chosen with a two-stage
#' heuristic:
#' \enumerate{
#'   \item Choose the smallest \eqn{\beta} whose signed scale-free fit
#'         \eqn{-\mathrm{sign}(\mathrm{slope}) \cdot R^2} is strictly greater
#'         than \code{r2_threshold}. The sign term rejects powers whose
#'         log-log slope is positive, i.e. good linear fits that are not
#'         scale-free.
#'   \item If no power reaches the target, select the power at the elbow of
#'         the mean-connectivity curve \eqn{f}: the interior \eqn{p_i} with
#'         the largest absolute second derivative. Because candidate powers
#'         need not be evenly spaced, the non-uniform central difference
#'         \deqn{f''(p_i) \approx \frac{2}{h_{i-1} + h_i}
#'               \left(\frac{f(p_{i+1}) - f(p_i)}{h_i}
#'               - \frac{f(p_i) - f(p_{i-1})}{h_{i-1}}\right),
#'               \quad h_i = p_{i+1} - p_i,}
#'         is used. For unit spacing it reduces to
#'         \eqn{f(p_{i+1}) - 2 f(p_i) + f(p_{i-1})}. Ties resolve to the
#'         smallest power.
#' }
#' If neither stage yields a power, the default power 6 is used. If
#' \code{pickSoftThreshold()} fails, a warning is issued and the default
#' power is returned. The result is always capped at \code{max_power}; a
#' message is emitted when the cap changes the selection.
#'
#' @param data_matrix Numeric matrix/data frame (rows = samples, cols = features).
#' @param r2_threshold Target signed \eqn{R^2} for the scale-free model
#'   (default 0.8). Must be a single non-missing number.
#' @param make_plots Logical; if \code{TRUE}, write diagnostic PNGs.
#' @param output_dir Output directory for plots when \code{make_plots = TRUE}.
#'   Defaults to \code{tempdir()} to comply with CRAN policies.
#' @param powers Positive numeric vector of candidate powers passed to
#'   \code{WGCNA::pickSoftThreshold()}. Defaults to
#'   \code{c(1:10, seq(12, 20, by = 2))}.
#' @param max_power Upper cap applied to the selected power (default 10), to
#'   avoid over-shrinking the network.
#'
#' @return Numeric scalar: the selected soft-thresholding power, capped at
#'   \code{max_power}.
#'
#' @importFrom grDevices png dev.off
#' @importFrom graphics  abline plot
#' @export
#' @examples
#' # --- 1. Create a small synthetic data matrix ---
#' set.seed(123)
#' example_data <- matrix(rnorm(50 * 100), nrow = 50, ncol = 100)
#'
#' # --- 2. Run the function to get the selected power ---
#' best_power <- select_soft_power(example_data)
#' print(paste("Best soft power:", best_power))
select_soft_power <- function(data_matrix, r2_threshold = 0.8, make_plots = FALSE,
                              output_dir = tempdir(),
                              powers = c(1:10, seq(from = 12, to = 20, by = 2)),
                              max_power = 10) {
    stopifnot(
        is.numeric(r2_threshold), length(r2_threshold) == 1L, !is.na(r2_threshold),
        is.numeric(powers), length(powers) >= 1L,
        all(is.finite(powers)), all(powers > 0),
        is.numeric(max_power), length(max_power) == 1L, !is.na(max_power)
    )

    default_power <- 6
    # Sorting guarantees that "first hit" means "smallest power" and that the
    # spacing used by the curvature estimate is positive.
    powers <- sort(unique(powers))

    sft <- tryCatch(
        WGCNA::pickSoftThreshold(data_matrix, powerVector = powers, verbose = 0),
        error = function(e) {
            warning("pickSoftThreshold failed. Returning default power. Error: ",
                    conditionMessage(e), call. = FALSE)
            NULL
        }
    )

    if (is.null(sft) || is.null(sft$fitIndices)) {
        # Safe default, still honouring the caller's cap.
        return(min(default_power, max_power))
    }

    fit_indices <- sft$fitIndices

    # SFT.R.sq is unsigned; a positive log-log slope means the fit is not
    # scale-free, so weight R^2 by -sign(slope) as in the WGCNA tutorial.
    signed_r2 <- fit_indices[["SFT.R.sq"]]
    if (!is.null(fit_indices[["slope"]])) {
        signed_r2 <- -sign(fit_indices[["slope"]]) * signed_r2
    }

    selected_power <- default_power
    hits <- which(signed_r2 > r2_threshold)  # which() drops NA comparisons

    if (length(hits) > 0L) {
        selected_power <- fit_indices$Power[hits[1L]]
        message("Selected power by signed R^2 > ", r2_threshold, ": ", selected_power)
    } else {
        elbow_power <- .sft_elbow_power(fit_indices$Power, fit_indices[["mean.k."]])
        if (!is.na(elbow_power)) {
            selected_power <- elbow_power
            message("R^2 threshold not met. Selected power by max curvature: ",
                    selected_power)
        }
    }

    # Cap the power to avoid over-shrinking the network; report it so the
    # message above never disagrees silently with the returned value.
    if (selected_power > max_power) {
        message("Selected power ", selected_power, " exceeds max_power; capping at ",
                max_power, ".")
        selected_power <- max_power
    }

    if (make_plots) {
        .write_sft_plots(fit_indices$Power, signed_r2, fit_indices[["mean.k."]],
                         r2_threshold, output_dir)
    }

    selected_power
}

# Return the interior power at which `values` has the largest absolute second
# derivative, estimated with the non-uniform central difference so unevenly
# spaced `powers` are handled correctly. `powers` must be sorted ascending.
# Returns NA_real_ when fewer than three points exist or every estimate is NA.
.sft_elbow_power <- function(powers, values) {
    n <- length(powers)
    if (n < 3L || length(values) != n) {
        return(NA_real_)
    }
    spacing <- diff(powers)
    slopes <- diff(values) / spacing
    # Element j refers to interior point powers[j + 1].
    curvature <- 2 * diff(slopes) / (spacing[-(n - 1L)] + spacing[-1L])
    if (all(is.na(curvature))) {
        return(NA_real_)
    }
    powers[which.max(abs(curvature)) + 1L]
}

# Write the signed-R^2 and mean-connectivity diagnostic PNGs into `output_dir`,
# closing each graphics device even if plotting fails.
.write_sft_plots <- function(powers, signed_r2, mean_k, r2_threshold, output_dir) {
    if (!dir.exists(output_dir)) {
        dir.create(output_dir, recursive = TRUE)
    }

    with_png <- function(file_name, draw) {
        png(file.path(output_dir, file_name), width = 800, height = 600)
        on.exit(dev.off(), add = TRUE)
        draw()
    }

    with_png("scale_free_fit_plot.png", function() {
        plot(powers, signed_r2,
             type = "b", col = "blue", pch = 20,
             xlab = "Soft Threshold (power)",
             ylab = "Scale-Free Topology Fit (signed R^2)",
             main = "Scale-Free Fit Index vs Power")
        abline(h = r2_threshold, col = "red", lty = 2)
    })

    with_png("mean_connectivity_plot.png", function() {
        plot(powers, mean_k,
             type = "b", col = "darkgreen", pch = 20,
             xlab = "Soft Threshold (power)", ylab = "Mean Connectivity",
             main = "Mean Connectivity vs Power")
    })

    invisible(output_dir)
}
