Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
@@ -1,8 +1,13 @@
# Generated by roxygen2: do not edit by hand

S3method("$",loo)
S3method("$",loo_diagnostics)
S3method("$<-",loo_diagnostics)
S3method("[",loo)
S3method("[",loo_diagnostics)
S3method("[[",loo)
S3method("[[",loo_diagnostics)
S3method("[[<-",loo_diagnostics)
S3method(.compute_point_estimate,default)
S3method(.compute_point_estimate,matrix)
S3method(.ndraws,default)
Expand Down
2 changes: 2 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
# loo (development version)

* Rename `diagnostics$n_eff` to `diagnostics$ess`. The `diagnostics` list now has class `"loo_diagnostics"`, with custom `$`, `[[`, and `[` methods that emit a deprecation warning when `n_eff` is accessed. The internal column name in `pareto_k_table` is now `"Min. ESS"`. (#192)

# loo 2.9.0

* Avoid under and overflows in stacking by @avehtari in #273
Expand Down
99 changes: 90 additions & 9 deletions R/diagnostics.R
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ NULL
#' @export
#' @return `pareto_k_table()` returns an object of class
#' `"pareto_k_table"`, which is a matrix with columns `"Count"`,
#' `"Proportion"`, and `"Min. n_eff"`, and has its own print method.
#' `"Proportion"`, and `"Min. ESS"`, and has its own print method.
#'
pareto_k_table <- function(x) {
k <- pareto_k_values(x)
Expand All @@ -132,7 +132,7 @@ pareto_k_table <- function(x) {
out <- cbind(
Count = count,
Proportion = prop.table(count),
"Min. n_eff" = min_n_eff
"Min. ESS" = min_n_eff
)
attr(out, "k_threshold") <- k_threshold
structure(out, class = c("pareto_k_table", class(out)))
Expand All @@ -152,8 +152,7 @@ print.pareto_k_table <- function(x, digits = 1, ...) {
" " = c("(good)", "(bad)", "(very bad)"),
"Count" = .fr(count, 0),
"Pct. " = paste0(.fr(100 * x[, "Proportion"], digits), "%"),
# Print ESS as n_eff terms has been deprecated
"Min. ESS" = round(x[, "Min. n_eff"])
"Min. ESS" = round(x[, "Min. ESS"])
)
tab2 <- rbind(tab)
cat("Pareto k diagnostic values:\n")
Expand Down Expand Up @@ -214,12 +213,14 @@ pareto_k_influence_values <- function(x) {
#' @return `psis_n_eff_values()` returns a vector of the estimated PSIS
#' effective sample sizes.
psis_n_eff_values <- function(x) {
n_eff <- x$diagnostics[["n_eff"]]
if (is.null(n_eff)) {
# Print ESS as n_eff terms has been deprecated
stop("No PSIS ESS estimates found.", call. = FALSE)
diag <- unclass(x$diagnostics)
if (!is.null(diag[["ess"]])) {
return(diag[["ess"]])
}
if (!is.null(diag[["n_eff"]])) {
return(diag[["n_eff"]])
}
return(n_eff)
stop("No PSIS ESS estimates found.", call. = FALSE)
}

#' @rdname pareto-k-diagnostic
Expand Down Expand Up @@ -438,3 +439,83 @@ min_n_eff_by_k <- function(n_eff, kcut) {
ps_khat_threshold <- function(S, ...) {
min(1 - 1 / log10(S), 0.7)
}


# loo_diagnostics class --------------------------------------------------

#' Create a diagnostics list with class `"loo_diagnostics"`
#'
#' @noRd
#' @param pareto_k Vector of Pareto k estimates.
#' @param ess Vector of PSIS effective sample size estimates.
#' @param r_eff Vector of relative MCMC effective sample sizes (optional).
#' @return A list with class `"loo_diagnostics"` containing `pareto_k`, `ess`,
#' `n_eff` (kept for backward compatibility), and optionally `r_eff`.
loo_diagnostics <- function(pareto_k, ess, r_eff = NULL) {
out <- list(pareto_k = pareto_k, ess = ess, n_eff = ess, r_eff = r_eff)
structure(out, class = "loo_diagnostics")
}

#' @export
`$.loo_diagnostics` <- function(x, name) {
if (identical(name, "n_eff")) {
warning(
"Accessing 'n_eff' using '$' is deprecated. ",
"Please use 'ess' instead.",
call. = FALSE
)
}
NextMethod()
}

#' @export
`[[.loo_diagnostics` <- function(x, i, exact = TRUE) {
if (is.character(i) && identical(i, "n_eff")) {
warning(
"Accessing 'n_eff' using '[[' is deprecated. ",
"Please use 'ess' instead.",
call. = FALSE
)
}
NextMethod()
}

#' @export
`[.loo_diagnostics` <- function(x, i) {
if (is.character(i) && identical(i, "n_eff")) {
warning(
"Accessing 'n_eff' using '[' is deprecated. ",
"Please use 'ess' instead.",
call. = FALSE
)
}
NextMethod()
}

#' @export
`$<-.loo_diagnostics` <- function(x, name, value) {
if (identical(name, "n_eff")) {
x <- unclass(x)
x[["n_eff"]] <- value
x[["ess"]] <- value
return(structure(x, class = "loo_diagnostics"))
}
if (identical(name, "ess")) {
x <- unclass(x)
x[["ess"]] <- value
x[["n_eff"]] <- value
return(structure(x, class = "loo_diagnostics"))
}
NextMethod()
}

#' @export
`[[<-.loo_diagnostics` <- function(x, i, value) {
if (is.character(i) && i %in% c("n_eff", "ess")) {
x <- unclass(x)
x[["n_eff"]] <- value
x[["ess"]] <- value
return(structure(x, class = "loo_diagnostics"))
}
NextMethod()
}
5 changes: 3 additions & 2 deletions R/importance_sampling.R
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ importance_sampling_object <-
out <- structure(
list(
log_weights = unnormalized_log_weights,
diagnostics = list(pareto_k = pareto_k, n_eff = NULL, r_eff = r_eff)
diagnostics = loo_diagnostics(pareto_k = pareto_k, ess = NULL, r_eff = r_eff)
),
# attributes
norm_const_log = norm_const_log,
Expand All @@ -166,7 +166,8 @@ importance_sampling_object <-

# need normalized weights (not on log scale) for psis_n_eff
w <- weights(out, normalize = TRUE, log = FALSE)
out$diagnostics[["n_eff"]] <- psis_n_eff(w, r_eff)
ess_val <- psis_n_eff(w, r_eff)
out$diagnostics <- loo_diagnostics(pareto_k = pareto_k, ess = ess_val, r_eff = r_eff)
return(out)
}

Expand Down
14 changes: 8 additions & 6 deletions R/loo.R
Original file line number Diff line number Diff line change
Expand Up @@ -71,12 +71,14 @@
#' }
#'
#' \item{`diagnostics`}{
#' A named list containing two vectors:
#' A named list (of class `"loo_diagnostics"`) containing:
#' * `pareto_k`: Importance sampling reliability diagnostics. By default,
#' these are equal to the `influence_pareto_k` in `pointwise`.
#' Some algorithms can improve importance sampling reliability and
#' modify these diagnostics. See the [pareto-k-diagnostic] page for details.
#' * `n_eff`: PSIS effective sample size estimates.
#' * `ess`: PSIS effective sample size estimates.
#' * `n_eff`: Deprecated alias for `ess`. Accessing `n_eff` will
#' produce a deprecation warning.
#' }
#'
#' \item{`psis_object`}{
Expand Down Expand Up @@ -286,9 +288,9 @@ loo.function <-
diagnostics <- psis_out$diagnostics
} else {
diagnostics_list <- lapply(psis_list, "[[", "diagnostics")
diagnostics <- list(
diagnostics <- loo_diagnostics(
pareto_k = psis_apply(diagnostics_list, "pareto_k"),
n_eff = psis_apply(diagnostics_list, "n_eff"),
ess = psis_apply(diagnostics_list, "ess"),
r_eff = psis_apply(diagnostics_list, "r_eff")
)
}
Expand Down Expand Up @@ -545,9 +547,9 @@ list2importance_sampling <- function(objects) {
structure(
list(
log_weights = log_weights,
diagnostics = list(
diagnostics = loo_diagnostics(
pareto_k = psis_apply(diagnostics, item = "pareto_k"),
n_eff = psis_apply(diagnostics, item = "n_eff"),
ess = psis_apply(diagnostics, item = "ess"),
r_eff = psis_apply(diagnostics, item = "r_eff")
)
),
Expand Down
4 changes: 2 additions & 2 deletions R/loo_approximate_posterior.R
Original file line number Diff line number Diff line change
Expand Up @@ -141,9 +141,9 @@ loo_approximate_posterior.function <-
diagnostics <- psis_out$diagnostics
} else {
diagnostics_list <- lapply(psis_list, "[[", "diagnostics")
diagnostics <- list(
diagnostics <- loo_diagnostics(
pareto_k = psis_apply(diagnostics_list, "pareto_k"),
n_eff = psis_apply(diagnostics_list, "n_eff")
ess = psis_apply(diagnostics_list, "ess")
)
}

Expand Down
2 changes: 1 addition & 1 deletion R/loo_moment_matching.R
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ loo_moment_match.default <- function(x, loo, post_draws, log_lik_i,
loo$pointwise[i, "looic"] <- mm_list[[ii]]$looic

loo$diagnostics$pareto_k[i] <- mm_list[[ii]]$k
loo$diagnostics$n_eff[i] <- mm_list[[ii]]$n_eff
loo$diagnostics[["ess"]][i] <- mm_list[[ii]]$n_eff
kfs[i] <- mm_list[[ii]]$kf

if (!is.null(loo$psis_object)) {
Expand Down
33 changes: 23 additions & 10 deletions R/loo_subsample.R
Original file line number Diff line number Diff line change
Expand Up @@ -981,10 +981,15 @@ rbind_psis_loo_ss <- function(object, x) {
checkmate::assert_disjunct(object$pointwise[, "idx"], x$pointwise[, "idx"])

object$pointwise <- rbind(object$pointwise, x$pointwise)
object$diagnostics$pareto_k <-
c(object$diagnostics$pareto_k, x$diagnostics$pareto_k)
object$diagnostics$n_eff <- c(object$diagnostics$n_eff, x$diagnostics$n_eff)
object$diagnostics$r_eff <- c(object$diagnostics$r_eff, x$diagnostics$r_eff)
new_diag <- unclass(object$diagnostics)
x_diag <- unclass(x$diagnostics)
new_ess <- if (!is.null(new_diag$ess)) new_diag$ess else new_diag$n_eff
x_ess <- if (!is.null(x_diag$ess)) x_diag$ess else x_diag$n_eff
object$diagnostics <- loo_diagnostics(
pareto_k = c(new_diag$pareto_k, x_diag$pareto_k),
ess = c(new_ess, x_ess),
r_eff = c(new_diag$r_eff, x_diag$r_eff)
)
attr(object, "dims")[2] <- nrow(object$pointwise)
object
}
Expand All @@ -1006,9 +1011,13 @@ remove_idx.psis_loo_ss <- function(object, idxs) {
row_map <- merge(row_map, idxs, by = "idx", all.y = TRUE)

object$pointwise <- object$pointwise[-row_map$row_no,,drop = FALSE]
object$diagnostics$pareto_k <- object$diagnostics$pareto_k[-row_map$row_no]
object$diagnostics$n_eff <- object$diagnostics$n_eff[-row_map$row_no]
object$diagnostics$r_eff <- object$diagnostics$r_eff[-row_map$row_no]
d <- unclass(object$diagnostics)
d_ess <- if (!is.null(d$ess)) d$ess else d$n_eff
object$diagnostics <- loo_diagnostics(
pareto_k = d$pareto_k[-row_map$row_no],
ess = d_ess[-row_map$row_no],
r_eff = d$r_eff[-row_map$row_no]
)
attr(object, "dims")[2] <- nrow(object$pointwise)
object
}
Expand All @@ -1028,9 +1037,13 @@ order.psis_loo_ss <- function(x, observations) {
row_map_obs <- data.frame(row_no_obs = 1:length(observations), idx = observations)
row_map <- merge(row_map_obs, row_map_x, by = "idx", sort = FALSE)
x$pointwise <- x$pointwise[row_map$row_no_x,,drop = FALSE]
x$diagnostics$pareto_k <- x$diagnostics$pareto_k[row_map$row_no_x]
x$diagnostics$n_eff <- x$diagnostics$n_eff[row_map$row_no_x]
x$diagnostics$r_eff <- x$diagnostics$r_eff[row_map$row_no_x]
d <- unclass(x$diagnostics)
d_ess <- if (!is.null(d$ess)) d$ess else d$n_eff
x$diagnostics <- loo_diagnostics(
pareto_k = d$pareto_k[row_map$row_no_x],
ess = d_ess[row_map$row_no_x],
r_eff = d$r_eff[row_map$row_no_x]
)
x
}

Expand Down
6 changes: 4 additions & 2 deletions R/psis.R
Original file line number Diff line number Diff line change
Expand Up @@ -36,11 +36,13 @@
#' class `"psis"`.
#' }
#' \item{`diagnostics`}{
#' A named list containing two vectors:
#' A named list (of class `"loo_diagnostics"`) containing:
#' * `pareto_k`: Estimates of the shape parameter \eqn{k} of the
#' generalized Pareto distribution. See the [pareto-k-diagnostic]
#' page for details.
#' * `n_eff`: PSIS effective sample size estimates.
#' * `ess`: PSIS effective sample size estimates.
#' * `n_eff`: Deprecated alias for `ess`. Accessing `n_eff` will
#' produce a deprecation warning.
#' }
#' }
#'
Expand Down
3 changes: 2 additions & 1 deletion R/sis.R
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,8 @@
#' \item{`diagnostics`}{
#' A named list containing one vector:
#' * `pareto_k`: Not used in `sis`, all set to 0.
#' * `n_eff`: effective sample size estimates.
#' * `ess`: Effective sample size estimates.
#' * `n_eff`: Deprecated alias for `ess`.
#' }
#' }
#'
Expand Down
3 changes: 2 additions & 1 deletion R/tis.R
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,8 @@
#' \item{`diagnostics`}{
#' A named list containing one vector:
#' * `pareto_k`: Not used in `tis`, all set to 0.
#' * `n_eff`: Effective sample size estimates.
#' * `ess`: Effective sample size estimates.
#' * `n_eff`: Deprecated alias for `ess`.
#' }
#' }
#'
Expand Down
Loading