FairModel
FairModel is the core class that implements a neural network with differentiable fairness constraints. It wraps a standard feedforward network with a cvxpylayers optimization layer that projects predictions onto the fairness constraint set.
fairness_training.FairModel
Bases: Module
Neural network with differentiable fairness constraints.
This model projects predictions from a standard neural network through a cvxpylayers optimization layer that enforces marginal fairness constraints while minimizing distortion from the original predictions.
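To build intuition for what the projection layer computes, here is a minimal sketch of the simplest case: the `'mean_pred'` metric, one binary protected attribute, and no prediction bounds. In that case the distortion-minimizing projection has a closed form. The constraint form `|mean_1 - mean_0| <= eps` and the helper `project_mean_gap` are assumptions for illustration; `FairModel` itself solves the general (bounded, possibly multi-attribute) problem with cvxpylayers.

```python
from typing import List, Sequence


def project_mean_gap(preds: Sequence[float], groups: Sequence[int], eps: float) -> List[float]:
    """Euclidean projection of preds onto {p : |mean_1(p) - mean_0(p)| <= eps}.

    Hypothetical helper. Assumes groups are 0/1, both present, and ignores
    prediction bounds so that the projection has a closed form.
    """
    n1 = sum(1 for g in groups if g == 1)
    n0 = len(groups) - n1
    if n0 == 0 or n1 == 0:
        raise ValueError("both groups must be present in the batch")
    mean1 = sum(p for p, g in zip(preds, groups) if g == 1) / n1
    mean0 = sum(p for p, g in zip(preds, groups) if g == 0) / n0
    gap = mean1 - mean0
    if abs(gap) <= eps:
        return list(preds)  # already feasible: zero distortion
    # How far the gap overshoots the tolerance (signed).
    excess = gap - eps if gap > 0 else gap + eps
    # Squared norm of the linear constraint vector a (a_i = 1/n1 or -1/n0).
    norm_sq = 1.0 / n1 + 1.0 / n0
    shift1 = -excess * (1.0 / n1) / norm_sq
    shift0 = excess * (1.0 / n0) / norm_sq
    return [p + (shift1 if g == 1 else shift0) for p, g in zip(preds, groups)]


print(project_mean_gap([0.9, 0.8, 0.2, 0.1], [1, 1, 0, 0], eps=0.1))
```

Because only the group means enter this constraint, every member of a group moves by the same amount and the gap lands exactly on the tolerance. Once `prediction_bounds` are active, clipping this result would not in general give the true projection, which is one reason a solver-backed layer is used.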
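Marginal fairness: for each protected attribute independently, the model enforces the fairness constraints (up to the tolerance `fairness_tolerance`) across that attribute's groups, rather than across intersections of attributes.
Training: always uses hard per-batch constraints, so `batch_size` should be at least `b_tau`.
Inference: uses hard constraints when `batch_size >= b_tau`; otherwise uses an online primal-dual algorithm.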
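The routing rule for training and inference can be summarized as a small decision function. This is a sketch of the documented rule only; `select_constraint_mode` is a hypothetical name, and `FairModel` performs this routing internally in `forward()`.

```python
def select_constraint_mode(batch_size: int, b_tau: int, inference: bool) -> str:
    """Return which fairness mechanism a batch would use (hypothetical helper)."""
    if not inference:
        # Training always uses hard per-batch constraints;
        # batches are expected to contain at least b_tau rows.
        return "hard"
    # Inference: large batches get hard constraints, small ones go online.
    return "hard" if batch_size >= b_tau else "primal_dual"


print(select_constraint_mode(32, b_tau=64, inference=True))
```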
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `input_dim` | `int` | Number of input features | required |
| `hidden_dims` | `List[int]` | List of hidden layer dimensions (optional if `custom_network` is provided) | `None` |
| `output_dim` | `int` | Output dimension (typically 1 for binary classification) | `1` |
| `protected_attr_idx` | `Union[int, List[int]]` | Index or list of indices of protected attribute columns | `0` |
| `prediction_bounds` | `Tuple[float, float]` | `(lower, upper)` bounds for predictions | `(0.0, 1.0)` |
| `fairness_tolerance` | `float` | Target fairness tolerance epsilon | `0.05` |
| `b_tau` | `int` | Batch size threshold for inference: batches of size `>= b_tau` use hard constraints, smaller batches use primal-dual | `64` |
| `eta_0` | `float` | Initial dual step size for inference primal-dual updates | `0.5` |
| `activation` | `str` | Activation function (`'relu'`, `'tanh'`, `'sigmoid'`, `'leaky_relu'`) | `'relu'` |
| `fairness_metric` | `Union[str, FairnessMetric]` | Either a string (`'mean_pred'`, `'mean_residual'`) or a `FairnessMetric` instance | `'mean_pred'` |
| `custom_network` | `Module` | Optional custom network architecture to use instead of the default | `None` |
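The two built-in metric names differ in what they average per group. The sketch below assumes `'mean_pred'` is the per-group mean prediction and `'mean_residual'` is the per-group mean of `pred - y`; the class `GroupMeanMetric` is hypothetical and is not the library's `FairnessMetric` interface, though it mirrors the `requires_targets` flag that `forward()` checks before demanding `y`.

```python
from dataclasses import dataclass
from typing import Dict, Optional, Sequence


@dataclass
class GroupMeanMetric:
    """Hypothetical stand-in for the 'mean_pred' and 'mean_residual' metrics."""

    name: str

    @property
    def requires_targets(self) -> bool:
        # Residual-based metrics need ground-truth targets.
        return self.name == "mean_residual"

    def group_values(self, preds: Sequence[float], groups: Sequence[int],
                     y: Optional[Sequence[float]] = None) -> Dict[int, float]:
        if self.requires_targets and y is None:
            raise ValueError(f"metric {self.name!r} requires targets y")
        if self.name == "mean_pred":
            vals = list(preds)
        elif self.name == "mean_residual":
            vals = [p - t for p, t in zip(preds, y)]  # residual assumed to be pred - y
        else:
            raise ValueError(f"unknown metric {self.name!r}")
        sums: Dict[int, float] = {}
        counts: Dict[int, int] = {}
        for v, g in zip(vals, groups):
            sums[g] = sums.get(g, 0.0) + v
            counts[g] = counts.get(g, 0) + 1
        return {g: sums[g] / counts[g] for g in sums}


print(GroupMeanMetric("mean_residual").group_values([1.0, 0.0], [1, 0], y=[1.0, 1.0]))
```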
forward(x, y=None, inference=False)
Forward pass that routes to inference or training implementations.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `x` | `Tensor` | Input tensor of shape `(batch_size, input_dim)` | required |
| `y` | `Optional[Tensor]` | Target tensor (required if `fairness_metric.requires_targets` is True) | `None` |
| `inference` | `bool` | If True, use inference mode (may use primal-dual for small batches) | `False` |
Returns:
| Type | Description |
|---|---|
| `Tensor` | Fair predictions of shape `(batch_size, output_dim)` |
wrap(network, protected_attr_idx, input_dim, fairness_tolerance=0.05, fairness_metric='mean_pred', prediction_bounds=None, b_tau=64, eta_0=0.5, exclude_protected_from_backbone=False)
classmethod
Wrap an existing nn.Module with a differentiable fairness layer.
Infers `output_dim` from a dry-run forward pass. If `prediction_bounds` is not provided, it is inferred from the dry-run output range (with a small margin) and a warning is issued so the user can verify or override it.
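The bound-inference step can be sketched as taking the min and max of the dry-run outputs, padding them by a margin, and warning. The helper `infer_prediction_bounds` and the 5% `margin_frac` are hypothetical choices; the actual margin `wrap()` uses is not documented here.

```python
import warnings
from typing import Sequence, Tuple


def infer_prediction_bounds(outputs: Sequence[float],
                            margin_frac: float = 0.05) -> Tuple[float, float]:
    """Infer (lower, upper) bounds from dry-run outputs, padded by a margin.

    Hypothetical helper; margin_frac is an assumed value.
    """
    lo, hi = min(outputs), max(outputs)
    # Guard against a zero-width range when all dry-run outputs coincide.
    margin = margin_frac * max(hi - lo, 1e-6)
    bounds = (lo - margin, hi + margin)
    warnings.warn(
        f"prediction_bounds inferred as {bounds}; pass prediction_bounds explicitly to override",
        UserWarning,
    )
    return bounds
```

A dry run only sees a handful of outputs, so the inferred range can be narrower than what the network can actually produce. For a network ending in `nn.Sigmoid`, passing `prediction_bounds=(0.0, 1.0)` explicitly avoids relying on it.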
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `network` | `Module` | Any `nn.Module` to wrap as the prediction backbone. | required |
| `protected_attr_idx` | `Union[int, List[int]]` | Column index or list of indices of protected attributes in the input tensor. | required |
| `input_dim` | `int` | Number of input features (columns in `X`). | required |
| `fairness_tolerance` | `float` | Target fairness tolerance epsilon. | `0.05` |
| `fairness_metric` | `Union[str, FairnessMetric]` | Metric name string or `FairnessMetric` instance. | `'mean_pred'` |
| `prediction_bounds` | `Optional[Tuple[float, float]]` | `(lower, upper)` bounds for predictions. If not provided, inferred from the dry-run output range with a small margin, and a warning is issued. | `None` |
| `b_tau` | `int` | Batch size threshold between hard-constraint and primal-dual inference. Batches `>= b_tau` use hard per-batch constraints; smaller batches use the online primal-dual algorithm. | `64` |
| `eta_0` | `float` | Initial dual step size for primal-dual inference. | `0.5` |
| `exclude_protected_from_backbone` | `bool` | If True, prevent the backbone from seeing the protected attribute(s) as input features. | `False` |
Returns:
| Type | Description |
|---|---|
| `FairModel` | Initialized `FairModel` wrapping `network`. |
Example:

```python
import torch.nn as nn
from fairness_training import FairModel

my_net = nn.Sequential(nn.Linear(20, 64), nn.ReLU(), nn.Linear(64, 1), nn.Sigmoid())
model = FairModel.wrap(my_net, protected_attr_idx=0, input_dim=20)

# Prevent backbone from seeing the protected attribute as a feature:
model = FairModel.wrap(my_net, protected_attr_idx=0, input_dim=20,
                       exclude_protected_from_backbone=True)
```
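One plausible reading of `exclude_protected_from_backbone=True` is that the protected columns are dropped before the input reaches the backbone. The sketch below computes which columns would remain under that assumption; `backbone_columns` is hypothetical, and the real `wrap()` may instead mask the column, so check its behavior before sizing the backbone's first layer.

```python
from typing import List, Union


def backbone_columns(input_dim: int, protected_attr_idx: Union[int, List[int]]) -> List[int]:
    """Column indices the backbone would see if protected columns are dropped.

    Hypothetical helper; assumes exclusion drops columns rather than masking them.
    """
    if isinstance(protected_attr_idx, int):
        protected = {protected_attr_idx}
    else:
        protected = set(protected_attr_idx)
    bad = sorted(i for i in protected if not 0 <= i < input_dim)
    if bad:
        raise IndexError(f"protected indices out of range: {bad}")
    return [i for i in range(input_dim) if i not in protected]


print(backbone_columns(5, [1, 3]))
```

Under this assumption, a backbone for `input_dim=20` with one protected column would receive 19 features.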
reset_inference_state()
Reset the primal-dual state used at inference. Call this before starting a new inference sequence.
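The following generic online primal-dual sketch shows what "state" means here. It keeps dual variables and running group statistics across small batches, and `reset()` clears them, analogous to `reset_inference_state()`. It is not `FairModel`'s algorithm. The single binary group, the `'mean_pred'` metric, the sign-based primal shift, and the `eta_0 / sqrt(t)` step schedule are all assumptions for illustration.

```python
import math
from typing import List, Sequence, Tuple


class OnlinePrimalDual:
    """Generic online primal-dual sketch for |mean_1 - mean_0| <= eps over a stream."""

    def __init__(self, eps: float, eta_0: float = 0.5,
                 bounds: Tuple[float, float] = (0.0, 1.0)) -> None:
        self.eps, self.eta_0, self.bounds = eps, eta_0, bounds
        self.reset()

    def reset(self) -> None:
        # Forget dual variables and running statistics.
        self.lam_pos = 0.0  # dual for  gap - eps <= 0
        self.lam_neg = 0.0  # dual for -gap - eps <= 0
        self.t = 0
        self.sums = [0.0, 0.0]
        self.counts = [0, 0]

    def step(self, preds: Sequence[float], groups: Sequence[int]) -> List[float]:
        lo, hi = self.bounds
        net = self.lam_pos - self.lam_neg
        # Primal step: argmin of 0.5*(p - p_hat)^2 + net*sign*p, clipped to bounds.
        out = [min(hi, max(lo, p - net if g == 1 else p + net))
               for p, g in zip(preds, groups)]
        for p, g in zip(out, groups):
            self.sums[g] += p
            self.counts[g] += 1
        # Dual step: projected gradient ascent on the running-gap violation.
        if self.counts[0] and self.counts[1]:
            gap = self.sums[1] / self.counts[1] - self.sums[0] / self.counts[0]
            self.t += 1
            eta = self.eta_0 / math.sqrt(self.t)
            self.lam_pos = max(0.0, self.lam_pos + eta * (gap - self.eps))
            self.lam_neg = max(0.0, self.lam_neg + eta * (-gap - self.eps))
        return out
```

Repeated calls with an unfair stream grow the active dual variable, which pulls later predictions toward each other. Without a reset, a new evaluation would inherit duals tuned to the previous stream.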
get_aggregate_fairness_stats(data_loader, reset_before=True)
Compute aggregate fairness statistics over a data loader.
Runs in inference mode, so batches smaller than `b_tau` use the primal-dual algorithm.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `data_loader` | | PyTorch `DataLoader` | required |
| `reset_before` | `bool` | If True, reset inference state before evaluation | `True` |
Returns:
| Type | Description |
|---|---|
| `dict` | Dictionary with aggregate statistics |
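Aggregate statistics are most meaningful when per-group sums and counts are pooled across all batches before comparing means. Averaging per-batch gaps would overweight small batches and break when a batch lacks one group. A minimal sketch, assuming a `'mean_pred'`-style metric; `aggregate_group_means` and its keys (`group_means`, `max_gap`, `n`) are hypothetical and need not match the dictionary returned by `get_aggregate_fairness_stats`.

```python
from typing import Dict, Iterable, Sequence, Tuple


def aggregate_group_means(
        batches: Iterable[Tuple[Sequence[float], Sequence[int]]]) -> Dict[str, object]:
    """Pool per-group sums over all batches, then compare group means once."""
    sums: Dict[int, float] = {}
    counts: Dict[int, int] = {}
    for preds, groups in batches:
        for p, g in zip(preds, groups):
            sums[g] = sums.get(g, 0.0) + p
            counts[g] = counts.get(g, 0) + 1
    means = {g: sums[g] / counts[g] for g in sums}
    return {
        "group_means": means,
        "max_gap": max(means.values()) - min(means.values()),
        "n": sum(counts.values()),
    }


# The second batch contains only group 1; pooling still handles it correctly.
print(aggregate_group_means([([0.8, 0.2], [1, 0]), ([0.6], [1])]))
```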
See Also
- FairTrainer - Training utilities
- FairnessMetric - Fairness metric base class
- Core Concepts - Theory and background