balnet object.cv.balnet object.
-- B --
balnet()
balweights()
balweights.balnet()
balweights.cv.balnet()
-- C --
coef.balnet()
coef.cv.balnet()
cv.balnet()
-- P --
plot.balnet()
plot.cv.balnet()
predict.balnet()
predict.cv.balnet()
print.balnet()
print.cv.balnet()
Fits regularized logistic regression models using covariate balancing loss functions, targeting the ATE, ATT, or treated/control means.
balnet(
X,
W,
target = c("ATE", "ATT", "treated", "control"),
sample.weights = NULL,
max.imbalance = NULL,
nlambda = 100L,
lambda.min.ratio = 0.01,
lambda = NULL,
penalty.factor = NULL,
groups = NULL,
alpha = 1,
standardize = TRUE,
tol = 1e-07,
maxit = as.integer(1e+05),
verbose = FALSE,
num.threads = 1L,
...
)
X |
A numeric matrix or data frame with pre-treatment covariates. |
W |
Treatment vector (0 = control, 1 = treated). |
target |
The target estimand. Default is "ATE". |
sample.weights |
Optional sample weights. If NULL (default), all observations are weighted equally. |
max.imbalance |
Optional upper bound on the standardized covariate imbalance. For lasso penalization
( |
nlambda |
Number of values for lambda along the regularization path. Default is 100. |
lambda.min.ratio |
Ratio of smallest to largest lambda. Default is 1e-2. |
lambda |
Optional user-supplied lambda sequence. Default is NULL. |
penalty.factor |
Penalty factor per feature. Default is 1 (i.e., each feature receives the same penalty). |
groups |
Optional list of group indices for group penalization. |
alpha |
Elastic net mixing parameter. Default is 1 (lasso), 0 corresponds to ridge. |
standardize |
Whether to standardize the input matrix. Should only be |
tol |
Coordinate descent convergence tolerance. Default is 1e-7. |
maxit |
Maximum number of coordinate descent iterations. Default is 1e5. |
verbose |
Whether to display information during fitting. Default is FALSE. |
num.threads |
Number of threads to use. Default is 1. |
... |
Additional internal arguments passed to the solver. |
This function aims to find balancing weights \(\hat\gamma_i\), using logistic propensity scores, that balance covariate means to a target vector, i.e.,
$$\frac{1}{n} \sum_{i=1}^n \hat\gamma_i X_i = \bar X_{\mathrm{target}}.$$
With lasso regularization (alpha = 1), imbalance is controlled in the \(\ell_\infty\) sense,
allowing absolute slack of at most \(\lambda\) per covariate.
For target = "ATE", two logistic models are fit, one per arm, with
$$\hat\gamma_i^{(1)} = \frac{W_i}{\hat e^{(1)}(X_i)}, \quad \hat\gamma_i^{(0)} = \frac{1 - W_i}{1 - \hat e^{(0)}(X_i)}, \quad \bar X_{\mathrm{target}} = \frac{1}{n} \sum_{i=1}^n X_i.$$
\(\hat e^{(w)}(X_i)\) is the fitted propensity score for arm \(w\).
For target = "ATT", control units are reweighted so that their covariate means match the treated means:
$$\hat\gamma_i = (1 - W_i) \frac{\hat e^{(0)}(X_i)}{1 - \hat e^{(0)}(X_i)}, \quad \bar X_{\mathrm{target}} = \frac{1}{\sum W_i} \sum_{i=1}^n W_i X_i.$$
A fitted balnet object.
Sverdrup, Erik and Trevor Hastie. "balnet: Pathwise Estimation of Covariate Balancing Propensity Scores". arXiv preprint, arXiv:2602.18577, 2026.
# Simulate data with confounding.
n <- 2000
p <- 10
X <- matrix(rnorm(n * p), n, p)
W <- rbinom(n, 1, 1 / (1.5 + exp(X[, 2] + X[, 3])))
Y <- W + 2 * log(1 + exp(X[, 1] + X[, 2] + X[, 3])) + rnorm(n)
# Fit model targeting the ATE = E[Y(1)] - E[Y(0)].
# Two logistic models are fit: one for treated, one for control.
fit <- balnet(X, W, target = "ATE")
# Print path summary.
print(fit)
#> Call: balnet(X = X, W = W, target = "ATE")
#>
#> Control (path: 100/100)
#> Nonzero Avg|SMD| Lambda
#> 1 0 0.047050 0.195447
#> 2 1 0.046026 0.186564
#> 3 1 0.045054 0.178084
#> ...
#> 98 9 0.002111 0.002145
#> 99 9 0.002022 0.002048
#> 100 9 0.001938 0.001954
#>
#> Treated (path: 100/100)
#> Nonzero Avg|SMD| Lambda
#> 1 0 0.083826 0.348217
#> 2 1 0.082751 0.332390
#> 3 1 0.081745 0.317282
#> ...
#> 98 9 0.003460 0.003822
#> 99 9 0.003309 0.003648
#> 100 9 0.003165 0.003482
# Visualize the path.
plot(fit)

# Plot the standardized covariate imbalance at given lambda.
# Note: lambda = 0 selects the final lambda in the sequence. Scalar values
# are applied to both arms.
plot(fit, lambda = 0)

# Predict propensity scores at end of lambda path.
W.hat <- predict(fit, X, lambda = 0)
# Get balancing weights at end of lambda path.
ipw.weights <- balweights(fit, lambda = 0)
# Estimate ATE using balancing weights.
mean(Y * (ipw.weights$treated - ipw.weights$control))
#> [1] 0.9165253
Retrieves the estimated balancing weights \(\hat{\gamma}\). Under unconfoundedness, these correspond to inverse probability weights (IPW) for standard treatment effect estimands.
balweights(object, lambda = NULL, ...)
## S3 method for class 'balnet'
balweights(object, lambda = NULL, ...)
## S3 method for class 'cv.balnet'
balweights(object, lambda = "lambda.min", ...)
object |
A balnet or cv.balnet object. |
lambda |
Value(s) of the penalty parameter lambda. Setting lambda = 0 selects the final lambda in the path. |
... |
Additional arguments (currently ignored). |
Estimated balancing weights
(for contrast fits, target = "ATE" or "ATT", returns a list with entries for each arm).
n <- 100
p <- 25
X <- matrix(rnorm(n * p), n, p)
W <- rbinom(n, 1, 1 / (1 + exp(1 - X[, 1])))
# Fit an ATT model.
fit <- balnet(X, W, target = "ATT")
# Extract balancing weights.
wts <- balweights(fit, lambda = 0)
Extract coefficients from a balnet object.
## S3 method for class 'balnet'
coef(object, lambda = NULL, ...)
object |
A balnet object. |
lambda |
Value(s) of the penalty parameter lambda. Setting lambda = 0 selects the final lambda in the path. |
... |
Additional arguments (currently ignored). |
Estimated logistic coefficients (for dual-arm fits, i.e. target = "ATE", returns a list with entries for each arm).
n <- 100
p <- 25
X <- matrix(rnorm(n * p), n, p)
W <- rbinom(n, 1, 1 / (1 + exp(1 - X[, 1])))
# Fit an ATT model.
fit <- balnet(X, W, target = "ATT")
# Extract coefficients.
coefs <- coef(fit)
Extract coefficients from a cv.balnet object.
## S3 method for class 'cv.balnet'
coef(object, lambda = "lambda.min", ...)
object |
A cv.balnet object. |
lambda |
The lambda to use. Defaults to the cross-validated lambda. |
... |
Additional arguments (currently ignored). |
Estimated logistic coefficients (for dual-arm fits, i.e. target = "ATE", returns a list with entries for each arm).
n <- 100
p <- 25
X <- matrix(rnorm(n * p), n, p)
W <- rbinom(n, 1, 1 / (1 + exp(1 - X[, 1])))
# Fit an ATT model.
cv.fit <- cv.balnet(X, W, target = "ATT")
# Extract coefficients at cross-validated lambda.
coefs <- coef(cv.fit)
Cross-validation for balnet.
cv.balnet(
X,
W,
type.measure = c("balance.loss"),
nfolds = 10,
foldid = NULL,
...
)
X |
A numeric matrix or data frame with pre-treatment covariates. |
W |
Treatment vector (0 = control, 1 = treated). |
type.measure |
The loss to minimize for cross-validation. Default is balance loss. |
nfolds |
The number of folds used for cross-validation, default is 10. |
foldid |
An optional vector of fold assignments, one per observation. If supplied, nfolds is not needed. |
... |
Arguments passed on to balnet() (e.g., target). |
A fitted cv.balnet object.
n <- 100
p <- 25
X <- matrix(rnorm(n * p), n, p)
W <- rbinom(n, 1, 1 / (1 + exp(1 - X[, 1])))
# Fit an ATE model.
cv.fit <- cv.balnet(X, W)
# Print CV summary.
print(cv.fit)
#> Call: cv.balnet(X = X, W = W)
#>
#> Cross-validated lambda minimizing type.measure = balance.loss:
#> Arm Nonzero Avg|SMD| Lambda Index
#> Control 1 0.08159 0.1905 8
#> Treated 15 0.09989 0.1327 31
# Plot at cross-validated lambda.
plot(cv.fit)

# Predict at cross-validated lambda.
W.hat <- predict(cv.fit, X)
Plot diagnostics for a balnet object.
Shows effective sample size (ESS) and percent bias reduction (PBR; reduction
in mean absolute imbalance) along the regularization path, computed from balancing
weights and normalized to percentages. The right-hand axis maps these values
to the coefficient of variation (CV) of the weights.
Supplying the lambda argument displays the standardized covariate imbalance
\((\bar X_{\mathrm{weighted}} - \bar X_{\mathrm{target}}) / \sigma_{\mathrm{target}}\),
computed using the balancing weights at the specified lambda.
## S3 method for class 'balnet'
plot(x, lambda = NULL, groups = NULL, max = NULL, ...)
x |
A balnet object. |
lambda |
If NULL (default), diagnostics over the lambda path are shown. Otherwise, covariate balance at the provided lambda value is shown (if target = "ATE", lambda can be a 2-vector giving the values for arm 0 and arm 1). |
groups |
Optional named list of contiguous covariate index ranges to
aggregate into a single variable before computing covariate imbalance
(e.g., groups = list(A = 1:3, B = 4:6)). |
max |
The number of covariates to display in the covariate balance plot; the covariates with the largest unweighted imbalance are shown. Defaults to all covariates. |
... |
Additional arguments. |
Invisibly returns the information underlying the plot.
n <- 100
p <- 25
X <- matrix(rnorm(n * p), n, p)
W <- rbinom(n, 1, 1 / (1 + exp(1 - X[, 1])))
# Fit an ATT model.
fit <- balnet(X, W, target = "ATT")
# Plot the five covariates with the largest unweighted imbalance
plot(fit, lambda = 0, max = 5)

Plot diagnostics for a cv.balnet object.
## S3 method for class 'cv.balnet'
plot(x, lambda = "lambda.min", ...)
x |
A cv.balnet object. |
lambda |
The lambda to use. Defaults to the cross-validated lambda. |
... |
Additional arguments. |
Invisibly returns the information underlying the plot.
n <- 100
p <- 25
X <- matrix(rnorm(n * p), n, p)
W <- rbinom(n, 1, 1 / (1 + exp(1 - X[, 1])))
# Fit an ATT model.
cv.fit <- cv.balnet(X, W, target = "ATT")
# Plot at cross-validated lambda.
plot(cv.fit)

Predict using a balnet object.
## S3 method for class 'balnet'
predict(object, newdata, lambda = NULL, type = c("response"), ...)
object |
A balnet object. |
newdata |
A numeric matrix. |
lambda |
Value(s) of the penalty parameter lambda. Setting lambda = 0 selects the final lambda in the path. |
type |
The type of predictions. Default is "response" (propensity scores). |
... |
Additional arguments (currently ignored). |
Estimated predictions (for dual-arm fits, i.e. target = "ATE", returns a list with entries for each arm).
n <- 100
p <- 25
X <- matrix(rnorm(n * p), n, p)
W <- rbinom(n, 1, 1 / (1 + exp(1 - X[, 1])))
# Fit an ATT model.
fit <- balnet(X, W, target = "ATT")
# Predict propensity scores.
W.hat <- predict(fit, X)
Predict using a cv.balnet object.
## S3 method for class 'cv.balnet'
predict(object, newdata, lambda = "lambda.min", type = c("response"), ...)
object |
A cv.balnet object. |
newdata |
A numeric matrix. |
lambda |
The lambda to use. Defaults to the cross-validated lambda. |
type |
The type of predictions. Default is "response" (propensity scores). |
... |
Additional arguments (currently ignored). |
Estimated predictions (for dual-arm fits, i.e. target = "ATE", returns a list with entries for each arm).
n <- 100
p <- 25
X <- matrix(rnorm(n * p), n, p)
W <- rbinom(n, 1, 1 / (1 + exp(1 - X[, 1])))
# Fit an ATT model.
cv.fit <- cv.balnet(X, W, target = "ATT")
# Predict propensity scores at cross-validated lambda.
W.hat <- predict(cv.fit, X)
Print a balnet object.
## S3 method for class 'balnet'
print(x, digits = max(3L, getOption("digits") - 3L), max = 3, ...)
x |
A balnet object. |
digits |
Number of digits to print. |
max |
Number of rows to show from each of the beginning and the end of the path. Default is 3. |
... |
Additional print arguments. |
Invisibly returns the printed information.
n <- 100
p <- 25
X <- matrix(rnorm(n * p), n, p)
W <- rbinom(n, 1, 1 / (1 + exp(1 - X[, 1])))
# Fit an ATT model.
fit <- balnet(X, W, target = "ATT")
# Print path summary.
print(fit)
#> Call: balnet(X = X, W = W, target = "ATT")
#>
#> Control (path: 84/100)
#> Nonzero Avg|SMD| Lambda
#> 1 0 0.20579 0.89878
#> 2 1 0.20315 0.85793
#> 3 1 0.20067 0.81893
#> ...
#> 82 23 0.02072 0.02076
#> 83 23 0.01911 0.01982
#> 84 21 0.01730 0.01892
Print a cv.balnet object.
## S3 method for class 'cv.balnet'
print(x, digits = max(3L, getOption("digits") - 3L), ...)
x |
A cv.balnet object. |
digits |
Number of digits to print. |
... |
Additional print arguments. |
Invisibly returns the printed information.
n <- 100
p <- 25
X <- matrix(rnorm(n * p), n, p)
W <- rbinom(n, 1, 1 / (1 + exp(1 - X[, 1])))
# Fit an ATT model.
cv.fit <- cv.balnet(X, W, target = "ATT")
# Print CV summary.
print(cv.fit)
#> Call: cv.balnet(X = X, W = W, target = "ATT")
#>
#> Cross-validated lambda minimizing type.measure = balance.loss:
#> Arm Nonzero Avg|SMD| Lambda Index
#> Control 2 0.1749 0.4733 21