FairTrainer

FairTrainer provides a high-level training interface for FairModel with built-in early stopping, learning rate scheduling, and fairness monitoring.


fairness_training.FairTrainer

Trainer for FairModel with fairness monitoring.

Handles the training loop, including:
- Training with hard per-batch fairness constraints
- Automatic inference state reset before validation
- Aggregate fairness tracking and reporting
- Early stopping based on validation loss
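The steps listed above can be sketched as a plain epoch loop. This is a minimal illustration under stated assumptions, not FairTrainer's code: fit_loop, train_epoch, validate, reset_inference_state and should_stop are hypothetical stand-ins, and the history keys train_loss and val_loss are assumed.

```python
from typing import Callable, Dict, List, Optional


def fit_loop(
    train_epoch: Callable[[], float],
    validate: Optional[Callable[[], float]],
    reset_inference_state: Callable[[], None],
    should_stop: Callable[[float], bool],
    epochs: int = 100,
    verbose: int = 1,
    log_interval: int = 10,
) -> Dict[str, List[float]]:
    """Hypothetical epoch loop mirroring the documented order of operations.

    All four callables are illustrative stand-ins, not FairTrainer methods,
    and the history keys are assumptions.
    """
    history: Dict[str, List[float]] = {"train_loss": [], "val_loss": []}
    for epoch in range(1, epochs + 1):
        # Training pass: the documented per-batch fairness constraints would apply here.
        history["train_loss"].append(train_epoch())

        val_loss: Optional[float] = None
        if validate is not None:
            # Clear state accumulated during training before validating.
            reset_inference_state()
            val_loss = validate()
            history["val_loss"].append(val_loss)

        # Report only every log_interval epochs, and only when verbose is on.
        if verbose >= 1 and epoch % log_interval == 0:
            msg = f"epoch {epoch}: train_loss={history['train_loss'][-1]:.4f}"
            if val_loss is not None:
                msg += f" val_loss={val_loss:.4f}"
            print(msg)

        # Early stopping is driven by validation loss, so it needs a validation pass.
        if val_loss is not None and should_stop(val_loss):
            break
    return history
```

Validation is optional here, as it is in fit, and the early-stopping check only runs when a validation loss exists.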

Parameters:

Name                     Type              Description                                                          Default
model                    Module            FairModel instance to train                                          required
criterion                Module            Loss function (e.g., nn.MSELoss())                                   required
optimizer                Optimizer         Optimizer (e.g., Adam)                                               required
device                   str               Device to train on ('cpu' or 'cuda')                                 'cpu'
scheduler                Optional[object]  Optional learning rate scheduler                                     None
early_stopping_patience  int               Epochs without improvement in validation loss before training stops  25
early_stopping_delta     float             Minimum decrease in validation loss that counts as an improvement    1e-05
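Read together, the two early-stopping parameters describe a patience counter: an epoch counts as an improvement only if validation loss falls more than early_stopping_delta below the best value seen so far, and training stops after early_stopping_patience consecutive epochs without one. The sketch below assumes that reading; the EarlyStopping class is hypothetical, and the strict comparison and counter reset are assumptions that FairTrainer may handle differently.

```python
import math


class EarlyStopping:
    """Patience-based early stopping on validation loss (illustrative sketch).

    Follows the documented meaning of early_stopping_patience and
    early_stopping_delta; FairTrainer's internals may differ in details.
    """

    def __init__(self, patience: int = 25, delta: float = 1e-5) -> None:
        self.patience = patience
        self.delta = delta
        self.best = math.inf  # best validation loss seen so far
        self.wait = 0  # consecutive epochs without sufficient improvement

    def step(self, val_loss: float) -> bool:
        """Record one epoch's validation loss; return True when training should stop."""
        if val_loss < self.best - self.delta:
            # Improvement larger than delta: remember it and reset patience.
            self.best = val_loss
            self.wait = 0
        else:
            # No sufficient improvement this epoch.
            self.wait += 1
        return self.wait >= self.patience
```

With patience=3 and delta=0.01, the losses 1.0, 0.8, 0.799, 0.798, 0.9 stop training at the fifth epoch: the drops to 0.799 and 0.798 are smaller than delta, so they do not reset the counter.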

fit(train_loader, val_loader=None, epochs=100, verbose=1, log_interval=10)

Train the model for up to the given number of epochs and return the training history.

Training enforces hard fairness constraints on every batch. Validation processes batches at whatever size val_loader yields, so the batch-size regime used during validation is set by that loader.

Parameters:

Name          Type                  Description                                           Default
train_loader  DataLoader            Training data loader (batch_size should be >= b_tau)  required
val_loader    Optional[DataLoader]  Validation data loader (optional)                     None
epochs        int                   Maximum number of epochs                              100
verbose       int                   Verbosity level (0=silent, 1=progress, 2=detailed)    1
log_interval  int                   Print every N epochs                                  10
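The hard per-batch constraint itself is not specified here. As one illustrative instance, assuming the constraint is exact mean parity (every sensitive group in a batch receives the same mean prediction), the Euclidean projection onto that constraint shifts each group's predictions by a constant so that its mean equals the batch mean. A projection like this needs every group represented in the batch, which is one plausible reason for the batch_size >= b_tau requirement on train_loader. project_mean_parity and its min_per_group argument are hypothetical; min_per_group is not b_tau itself.

```python
from collections import defaultdict
from collections.abc import Hashable
from typing import Dict, List, Sequence


def project_mean_parity(
    preds: Sequence[float], groups: Sequence[Hashable], min_per_group: int = 1
) -> List[float]:
    """Shift each group's predictions so every group mean equals the batch mean.

    This is the least-squares projection onto the set of prediction vectors
    with equal group means (an assumed form of the fairness constraint).
    """
    if len(preds) != len(groups):
        raise ValueError("preds and groups must have the same length")
    members: Dict[Hashable, List[int]] = defaultdict(list)
    for i, g in enumerate(groups):
        members[g].append(i)
    # Refuse batches in which some present group is too thinly represented.
    if any(len(idx) < min_per_group for idx in members.values()):
        raise ValueError("batch too small: a group has fewer than min_per_group samples")
    batch_mean = sum(preds) / len(preds)
    out = list(preds)
    for idx in members.values():
        group_mean = sum(preds[i] for i in idx) / len(idx)
        shift = batch_mean - group_mean  # constant offset for this group
        for i in idx:
            out[i] = preds[i] + shift
    return out
```

For predictions [1.0, 2.0, 3.0, 6.0] with groups ["a", "a", "b", "b"], the batch mean is 3.0 and the result is [2.5, 3.5, 1.5, 4.5], so both groups now average 3.0.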

Returns:

Type  Description
Dict  Dictionary containing the training history

evaluate(test_loader, return_predictions=False)

Evaluate the model on a test set.

Resets the inference state before evaluation, then runs inference on every batch, using the primal-dual procedure for small batches where applicable.
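Neither the primal-dual procedure nor the contents of the inference state are specified here. The sketch below is a hypothetical example of a stateful primal-dual correction for batches too small to constrain on their own. Each group g has a dual variable lambda_g; the primal step outputs p + lambda_g, which minimizes 0.5 * (z - p)^2 - lambda_g * z; the dual step moves lambda_g to close the gap between the running group mean and the running overall mean. Because the duals and running sums persist across batches, they are cleared with reset() before a new pass, analogous to the inference-state reset that evaluate performs. StreamingParityCorrector and its methods are illustrative names, not part of FairTrainer.

```python
from collections import defaultdict
from collections.abc import Hashable
from typing import Dict, List, Sequence


class StreamingParityCorrector:
    """Hypothetical primal-dual correction for small batches (illustrative only)."""

    def __init__(self, step_size: float = 0.5) -> None:
        self.step_size = step_size
        self.reset()

    def reset(self) -> None:
        """Clear dual variables and running statistics before a new pass."""
        self.duals: Dict[Hashable, float] = defaultdict(float)
        self.group_sum: Dict[Hashable, float] = defaultdict(float)
        self.group_count: Dict[Hashable, int] = defaultdict(int)

    def running_gap(self) -> float:
        """Largest difference between running group means of corrected outputs."""
        means = [self.group_sum[g] / self.group_count[g] for g in self.group_count]
        return max(means) - min(means) if means else 0.0

    def step(self, preds: Sequence[float], groups: Sequence[Hashable]) -> List[float]:
        if len(preds) != len(groups):
            raise ValueError("preds and groups must have the same length")
        if not preds:
            return []
        # Primal step: z = p + lambda_g for each sample in group g.
        out = [p + self.duals[g] for p, g in zip(preds, groups)]
        for z, g in zip(out, groups):
            self.group_sum[g] += z
            self.group_count[g] += 1
        overall = sum(self.group_sum.values()) / sum(self.group_count.values())
        # Dual step: raise lambda_g for groups below the running overall mean,
        # lower it for groups above it.
        for g in self.group_count:
            residual = overall - self.group_sum[g] / self.group_count[g]
            self.duals[g] += self.step_size * residual
        return out
```

Fed the same two-sample batch repeatedly (group a predicted 0.0, group b predicted 1.0), the running gap shrinks from 1.0 to 0.75 after the second batch and keeps falling, even though no single batch is projected exactly.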

Parameters:

Name                Type        Description                        Default
test_loader         DataLoader  Test data loader                   required
return_predictions  bool        If True, return predictions array  False

Returns:

Type  Description
Dict  Dictionary of test metrics, including per-attribute fairness gaps
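The exact gap definition is not documented. Assuming a demographic-parity-style gap, computed for each sensitive attribute as the largest minus the smallest group mean prediction, a per-attribute report could look like the sketch below. The function name and the attributes mapping are hypothetical; evaluate may use a different metric or different key names.

```python
from collections import defaultdict
from collections.abc import Hashable
from typing import Dict, Sequence


def per_attribute_gaps(
    preds: Sequence[float], attributes: Dict[str, Sequence[Hashable]]
) -> Dict[str, float]:
    """Max-minus-min group mean prediction for each sensitive attribute (assumed metric)."""
    gaps: Dict[str, float] = {}
    for name, groups in attributes.items():
        if len(groups) != len(preds):
            raise ValueError(
                f"attribute {name!r} has {len(groups)} labels for {len(preds)} predictions"
            )
        sums: Dict[Hashable, float] = defaultdict(float)
        counts: Dict[Hashable, int] = defaultdict(int)
        for p, g in zip(preds, groups):
            sums[g] += p
            counts[g] += 1
        means = [sums[g] / counts[g] for g in counts]
        # An attribute with no observations has no measurable gap.
        gaps[name] = max(means) - min(means) if means else 0.0
    return gaps
```

With predictions [1.0, 2.0, 3.0, 6.0], an attribute whose groups split them as (1.0, 2.0) versus (3.0, 6.0) has a gap of 4.5 - 1.5 = 3.0.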

save_checkpoint(filepath, include_history=True)

Save a model checkpoint, including the fairness hyperparameters.

load_checkpoint(filepath, load_optimizer=True)

Load a model checkpoint.

If the checkpoint contains a 'model_config' entry, each of its values is compared against the current model's configuration, and a warning is issued for every mismatch so that silent configuration drift is caught early.
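The mismatch check described above amounts to a key-by-key comparison. A minimal sketch, assuming both configurations are plain dictionaries; check_model_config is a hypothetical helper, and any config keys used with it (such as b_tau) are illustrative.

```python
import warnings
from typing import Any, Dict, List

_MISSING = object()  # sentinel for keys absent from the current config


def check_model_config(saved: Dict[str, Any], current: Dict[str, Any]) -> List[str]:
    """Warn once per saved config key whose value differs from the current model's."""
    mismatched: List[str] = []
    for key, saved_value in saved.items():
        current_value = current.get(key, _MISSING)
        if current_value is _MISSING or current_value != saved_value:
            shown = "<missing>" if current_value is _MISSING else repr(current_value)
            warnings.warn(
                f"model_config mismatch for {key!r}: checkpoint has {saved_value!r}, "
                f"current model has {shown}",
                UserWarning,
                stacklevel=2,
            )
            mismatched.append(key)
    return mismatched
```

Comparing with != rather than identity means that equal but distinct values, such as two equal lists, are not reported as drift.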


See Also