# Seed R's random number generator so the rnorm() draws that simulate data in
# the tests below are reproducible from run to run.
# NOTE(review): whether this also fixes any randomness inside div() is not
# visible from this file -- confirm if exact reproducibility of fits matters.
set.seed(123)

test_that("predict.DIV returns vectors of the right length", {
  skip_if_no_torch()

  # Problem sizes are deliberately tiny so the fit stays cheap.
  n_obs <- 80L
  z_cols <- 2L
  x_cols <- 1L

  # Training data. The latent term enters both X (linearly) and Y (squared).
  z_fit <- matrix(rnorm(n_obs * z_cols), ncol = z_cols)
  latent_fit <- rnorm(n_obs)
  x_noise <- rnorm(n_obs, sd = 0.1)
  x_fit <- matrix(
    0.2 * z_fit[, 1] + 0.1 * z_fit[, 2] + latent_fit + x_noise,
    ncol = x_cols
  )
  y_noise <- rnorm(n_obs, sd = 0.1)
  y_fit <- 0.5 * x_fit[, 1] + 0.1 * latent_fit^2 + y_noise

  # Noise dimensions, hidden width and epoch count are read from `small_cfg`,
  # which is defined outside this file.
  fit <- div(
    Z = z_fit,
    X = x_fit,
    Y = y_fit,
    epsx_dim = small_cfg$eps[1],
    epsy_dim = small_cfg$eps[2],
    epsh_dim = small_cfg$eps[3],
    hidden_dim = small_cfg$hidden,
    num_layer = 3L,
    num_epochs = small_cfg$epochs,
    lr = 1e-3,
    silent = TRUE
  )

  # Fresh X values generated the same way as the training X.
  z_new <- matrix(rnorm(n_obs * z_cols), ncol = z_cols)
  latent_new <- rnorm(n_obs)
  x_new_noise <- rnorm(n_obs, sd = 0.1)
  x_new <- matrix(
    0.2 * z_new[, 1] + 0.1 * z_new[, 2] + latent_new + x_new_noise,
    ncol = x_cols
  )

  # type = "mean" with drop = TRUE: one value per test row.
  pred_mean <- predict(fit, Xtest = x_new, type = "mean", drop = TRUE)
  expect_true(is.atomic(pred_mean))
  expect_length(pred_mean, n_obs)

  # A single quantile with drop = TRUE: again one value per test row.
  pred_q05 <- predict(
    fit,
    Xtest = x_new,
    type = "quantile",
    quantiles = 0.05,
    drop = TRUE
  )
  expect_true(is.atomic(pred_q05))
  expect_length(pred_q05, n_obs)

  # Several quantiles with drop = FALSE: expect an n x 1 x length(probs) array.
  probs <- c(0.1, 0.5, 0.9)
  pred_q <- predict(
    fit,
    Xtest = x_new,
    type = "quantile",
    quantiles = probs,
    drop = FALSE
  )
  expected_dim <- c(n_obs, 1L, length(probs))
  expect_true(is.array(pred_q))
  expect_identical(dim(pred_q), expected_dim)
})

test_that("predict.DIV with W works and respects dimensions", {
  skip_if_no_torch()

  # Small simulated problem: 2 columns in Z, 1 in X, 2 in the extra input W.
  n <- 60L
  pZ <- 2L
  pX <- 1L
  pW <- 2L

  # Training data. X depends on Z[, 1], W[, 1] and H; Y depends on X and
  # W[, 2].
  Z <- matrix(rnorm(n * pZ), ncol = pZ)
  W <- matrix(rnorm(n * pW), ncol = pW)
  H <- rnorm(n)
  X <- matrix(0.3 * Z[, 1] + 0.2 * W[, 1] + H + rnorm(n, sd = 0.1), ncol = pX)
  Y <- 0.4 * X[, 1] + 0.1 * W[, 2] + rnorm(n, sd = 0.1)

  # Network sizes and epoch count come from `small_cfg`, which is defined
  # outside this file.
  m <- div(
    Z = Z, X = X, Y = Y, W = W,
    epsx_dim = small_cfg$eps[1], epsy_dim = small_cfg$eps[2], epsh_dim = small_cfg$eps[3],
    hidden_dim = small_cfg$hidden, num_layer = 3L,
    num_epochs = small_cfg$epochs, lr = 1e-3, silent = TRUE
  )

  # Test inputs generated the same way as the training X and W.
  Zt <- matrix(rnorm(n * pZ), ncol = pZ)
  Wt <- matrix(rnorm(n * pW), ncol = pW)
  Ht <- rnorm(n)
  Xt <- matrix(0.3 * Zt[, 1] + 0.2 * Wt[, 1] + Ht + rnorm(n, sd = 0.1), ncol = pX)

  # mean with drop = TRUE: one value per test row.
  out_mean <- predict(m, Xtest = Xt, Wtest = Wt, type = "mean", drop = TRUE)
  expect_true(is.atomic(out_mean))
  expect_length(out_mean, n)

  # sample with drop = FALSE: only the container type is checked, not its
  # dimensions. Every matrix is an array (a matrix is an array whose dim has
  # length 2), so is.array() alone covers both the matrix and the
  # higher-dimensional case.
  out_samp <- predict(m, Xtest = Xt, Wtest = Wt, type = "sample", nsample = 3, drop = FALSE)
  expect_true(is.array(out_samp))
})
