FairModel

FairModel is the core class that implements a neural network with differentiable fairness constraints. It wraps a standard feedforward network with a cvxpylayers optimization layer that projects predictions onto the fairness constraint set.


fairness_training.FairModel

Bases: Module

Neural network with differentiable fairness constraints.

This model projects predictions from a standard neural network through a cvxpylayers optimization layer that enforces marginal fairness constraints while minimizing distortion from the original predictions.
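To make the projection idea concrete, here is a minimal sketch of a minimum-distortion projection in a simplified setting: two groups, a mean-prediction gap constraint, and no prediction bounds. In that case the constraint set is a slab and the Euclidean projection has a closed form. This is an illustration only, not the library's cvxpylayers layer; `project_mean_parity` is a hypothetical helper.

```python
def project_mean_parity(preds, groups, eps):
    """Project preds onto {z : |mean(z[g == 0]) - mean(z[g != 0])| <= eps}.

    Simplifying assumptions: exactly two groups, no box bounds on z.
    """
    n0 = sum(1 for g in groups if g == 0)
    n1 = len(groups) - n0
    if n0 == 0 or n1 == 0:
        return list(preds)  # only one group present, nothing to constrain
    # a[i] is the coefficient of preds[i] in the gap mean_0 - mean_1.
    a = [1.0 / n0 if g == 0 else -1.0 / n1 for g in groups]
    gap = sum(ai * pi for ai, pi in zip(a, preds))
    if abs(gap) <= eps:
        return list(preds)  # already feasible, so zero distortion
    target = eps if gap > 0 else -eps
    # Move along a just far enough to land on the nearest slab face.
    scale = (gap - target) / sum(ai * ai for ai in a)
    return [pi - scale * ai for ai, pi in zip(a, preds)]


preds = [0.9, 0.7, 0.2, 0.4]
groups = [0, 0, 1, 1]
fair = project_mean_parity(preds, groups, eps=0.05)
new_gap = (fair[0] + fair[1]) / 2 - (fair[2] + fair[3]) / 2
```

Once prediction bounds are added, the problem generally no longer has a closed form, which is where a convex optimization layer becomes useful.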

Marginal Fairness: For each protected attribute independently, the model enforces the fairness constraint across that attribute's groups.
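As a sketch of what "independently" means here, the snippet below measures the gap in group mean predictions separately for each protected attribute. It assumes that the 'mean_pred' metric compares group mean predictions; the helper `marginal_gaps` and the attribute names are hypothetical.

```python
from collections import defaultdict


def marginal_gaps(preds, protected_columns):
    """Return, per protected attribute, the largest gap between group means."""
    gaps = {}
    for name, column in protected_columns.items():
        totals = defaultdict(float)
        counts = defaultdict(int)
        for p, g in zip(preds, column):
            totals[g] += p
            counts[g] += 1
        means = [totals[g] / counts[g] for g in counts]
        gaps[name] = max(means) - min(means)
    return gaps


preds = [0.8, 0.6, 0.4, 0.2]
gaps = marginal_gaps(preds, {"attr_a": [0, 0, 1, 1], "attr_b": [0, 1, 0, 1]})
```

Each attribute is checked on its own, so a combination such as (attr_a = 0, attr_b = 1) is never treated as a separate group.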

Training: Always uses hard per-batch constraints (batch_size should be >= b_tau).

Inference: Uses hard constraints if batch_size >= b_tau; otherwise uses the online primal-dual algorithm.
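The routing rule can be restated as a small decision function. This is a sketch of the rule as documented, not the library's code; `select_constraint_mode` and the mode strings are hypothetical.

```python
def select_constraint_mode(batch_size, b_tau=64, inference=False):
    """Return which constraint mechanism the documented rule selects."""
    if not inference:
        return "hard"  # training always uses hard per-batch constraints
    return "hard" if batch_size >= b_tau else "primal_dual"


modes = [
    select_constraint_mode(128, inference=False),
    select_constraint_mode(64, inference=True),
    select_constraint_mode(8, inference=True),
]
```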

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| input_dim | int | Number of input features | required |
| hidden_dims | List[int] | List of hidden layer dimensions (optional if custom_network provided) | None |
| output_dim | int | Output dimension (typically 1 for binary classification) | 1 |
| protected_attr_idx | Union[int, List[int]] | Index or list of indices of protected attribute columns | 0 |
| prediction_bounds | Tuple[float, float] | (lower, upper) bounds for predictions | (0.0, 1.0) |
| fairness_tolerance | float | Target fairness tolerance epsilon | 0.05 |
| b_tau | int | Batch size threshold for inference: batches >= b_tau use hard constraints, smaller batches use primal-dual | 64 |
| eta_0 | float | Initial dual step size for inference primal-dual updates | 0.5 |
| activation | str | Activation function ('relu', 'tanh', 'sigmoid', 'leaky_relu') | 'relu' |
| fairness_metric | Union[str, FairnessMetric] | Either a string ('mean_pred', 'mean_residual') or a FairnessMetric instance | 'mean_pred' |
| custom_network | Module | Optional custom network architecture to use instead of default | None |

forward(x, y=None, inference=False)

Forward pass that routes to inference or training implementations.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| x | Tensor | Input tensor of shape (batch_size, input_dim) | required |
| y | Optional[Tensor] | Target tensor (required if fairness_metric.requires_targets is True) | None |
| inference | bool | If True, use inference mode (may use primal-dual for small batches) | False |

Returns:

| Type | Description |
| --- | --- |
| Tensor | Fair predictions of shape (batch_size, output_dim) |

wrap(network, protected_attr_idx, input_dim, fairness_tolerance=0.05, fairness_metric='mean_pred', prediction_bounds=None, b_tau=64, eta_0=0.5, exclude_protected_from_backbone=False) classmethod

Wrap an existing nn.Module with a differentiable fairness layer.

Infers output_dim from a dry-run forward pass. If prediction_bounds is not provided, it is inferred from the dry-run output range (with a small margin) and a warning is issued so the user can verify or override.
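The bounds inference can be pictured with the sketch below, which pads the observed output range and warns the caller. The padding rule (a fraction of the observed range) and the helper name `infer_prediction_bounds` are assumptions; the source only states that a small margin is added and a warning is issued.

```python
import warnings


def infer_prediction_bounds(probe_outputs, margin=0.05):
    """Guess (lower, upper) bounds from probe outputs, padded by a margin."""
    lo, hi = min(probe_outputs), max(probe_outputs)
    # Assumed rule: pad by a fraction of the range, or by margin if degenerate.
    pad = margin * (hi - lo) if hi > lo else margin
    bounds = (lo - pad, hi + pad)
    warnings.warn(
        f"prediction_bounds inferred as {bounds}; "
        "pass prediction_bounds explicitly to override.",
        stacklevel=2,
    )
    return bounds


bounds = infer_prediction_bounds([0.2, 0.5, 0.8])
```

A single probe may not cover the network's full output range. When the range is known in advance, for example (0.0, 1.0) for a sigmoid output, passing prediction_bounds explicitly avoids relying on the guess.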

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| network | Module | Any nn.Module that accepts (batch_size, input_dim) input and returns (batch_size, output_dim) output. | required |
| protected_attr_idx | Union[int, List[int]] | Column index or list of indices of protected attributes in the input tensor. | required |
| input_dim | int | Number of input features (columns in X). | required |
| fairness_tolerance | float | Target fairness tolerance epsilon. | 0.05 |
| fairness_metric | Union[str, FairnessMetric] | Metric name string or FairnessMetric instance. | 'mean_pred' |
| prediction_bounds | Optional[Tuple[float, float]] | (lb, ub) for the fairness layer. If None, inferred from the network's output on a probe input. | None |
| b_tau | int | Batch size threshold between hard-constraint and primal-dual inference. Batches >= b_tau use hard per-batch constraints; smaller batches use the online primal-dual algorithm. | 64 |
| eta_0 | float | Initial dual step size for primal-dual inference. | 0.5 |
| exclude_protected_from_backbone | bool | If True, wrap network in a thin layer that strips the protected attribute columns before forwarding, so the backbone never sees the protected attribute as a feature. The fairness layer still receives the full input including protected columns. Defaults to False for backward compatibility. | False |
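The column-stripping behaviour described for exclude_protected_from_backbone can be sketched in plain Python as a wrapper that drops the protected columns before calling the backbone. This is an illustration of the idea only; `DropColumns` is a hypothetical class, and the library's wrapper would operate on tensors inside an nn.Module.

```python
class DropColumns:
    """Call a backbone on the input with protected columns removed."""

    def __init__(self, backbone, protected_attr_idx):
        idx = protected_attr_idx
        self.drop = {idx} if isinstance(idx, int) else set(idx)
        self.backbone = backbone

    def __call__(self, rows):
        # Keep every column whose index is not a protected attribute.
        kept = [[v for j, v in enumerate(row) if j not in self.drop]
                for row in rows]
        return self.backbone(kept)


def backbone(rows):
    # Stand-in model: sums the features it is given.
    return [sum(row) for row in rows]


wrapped = DropColumns(backbone, protected_attr_idx=0)
out = wrapped([[1.0, 0.5, 0.25], [0.0, 2.0, 3.0]])
```

If the stripping works as described, the backbone receives input_dim minus the number of protected columns, so its first layer would presumably need to be sized for the reduced width. This is worth verifying against the library before relying on it.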

Returns:

| Type | Description |
| --- | --- |
| FairModel | Initialized FairModel using network as its backbone. |

Example::

    import torch.nn as nn
    from fairness_training import FairModel

    my_net = nn.Sequential(nn.Linear(20, 64), nn.ReLU(), nn.Linear(64, 1), nn.Sigmoid())
    model = FairModel.wrap(my_net, protected_attr_idx=0, input_dim=20)

    # Prevent backbone from seeing the protected attribute as a feature:
    model = FairModel.wrap(my_net, protected_attr_idx=0, input_dim=20,
                           exclude_protected_from_backbone=True)

reset_inference_state()

Reset the primal-dual state used for inference. Call this before starting a new inference sequence.
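To see why a reset matters, consider a generic online dual-ascent sketch in which a dual variable accumulates constraint violations across batches. The update rule, the 1/sqrt(t) step decay, and the `DualState` class are assumptions made for illustration; the source does not specify the library's actual primal-dual update.

```python
import math


class DualState:
    """Hypothetical dual variable carried across small inference batches."""

    def __init__(self, eps=0.05, eta_0=0.5):
        self.eps = eps
        self.eta_0 = eta_0
        self.reset()

    def reset(self):
        self.lam = 0.0  # dual variable (price of violating the constraint)
        self.t = 0      # number of batches seen in this sequence

    def update(self, gap):
        self.t += 1
        eta = self.eta_0 / math.sqrt(self.t)  # assumed decaying step size
        # Projected ascent: grow when |gap| exceeds eps, never go negative.
        self.lam = max(0.0, self.lam + eta * (abs(gap) - self.eps))
        return self.lam


state = DualState()
for gap in [0.25, 0.15, 0.01]:
    state.update(gap)
carried = state.lam  # history from this sequence is still present
state.reset()        # start the next sequence from a clean slate
```

Without the reset, state accumulated on one dataset would influence the first predictions made on the next.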

get_aggregate_fairness_stats(data_loader, reset_before=True)

Compute aggregate fairness statistics over a data loader.

Runs in inference mode, so batches smaller than b_tau use the primal-dual algorithm.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| data_loader | | PyTorch DataLoader | required |
| reset_before | bool | If True, reset inference state before evaluation | True |

Returns:

| Type | Description |
| --- | --- |
| dict | Dictionary with aggregate statistics |
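Aggregate statistics are naturally computed from running per-group sums over all batches rather than by averaging per-batch gaps. The sketch below shows that accumulation pattern in plain Python. The function name and the dictionary keys are hypothetical, since the keys of the returned dictionary are not documented here.

```python
def aggregate_group_means(batches):
    """Accumulate per-group sums over (preds, groups) batches."""
    totals, counts = {}, {}
    for preds, groups in batches:
        for p, g in zip(preds, groups):
            totals[g] = totals.get(g, 0.0) + p
            counts[g] = counts.get(g, 0) + 1
    means = {g: totals[g] / counts[g] for g in totals}
    return {
        "group_means": means,
        "gap": max(means.values()) - min(means.values()),
        "n": sum(counts.values()),
    }


# The second batch contains only group 1, so a per-batch gap is undefined
# there, but the aggregate over both batches is still well defined.
batches = [([0.9, 0.1], [0, 1]), ([0.5, 0.5, 0.5], [1, 1, 1])]
stats = aggregate_group_means(batches)
```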


See Also