FairTrainer
FairTrainer provides a high-level training interface for FairModel with built-in early stopping, learning rate scheduling, and fairness monitoring.
fairness_training.FairTrainer
Trainer for FairModel with fairness monitoring.
Handles the training loop, including:
- Training with hard per-batch fairness constraints
- Automatic inference state reset before validation
- Aggregate fairness tracking and reporting
- Early stopping based on validation loss
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| model | Module | FairModel instance to train | required |
| criterion | Module | Loss function (e.g., nn.MSELoss()) | required |
| optimizer | Optimizer | Optimizer (e.g., torch.optim.Adam) | required |
| device | str | Device to train on ('cpu' or 'cuda') | 'cpu' |
| scheduler | Optional[object] | Optional learning rate scheduler | None |
| early_stopping_patience | int | Number of epochs without improvement in validation loss before training stops early | 25 |
| early_stopping_delta | float | Minimum decrease in validation loss that counts as an improvement and resets the patience counter | 1e-05 |
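A common way for a patience/delta pair to interact: an epoch counts as an improvement only if validation loss falls more than early_stopping_delta below the best value seen so far, and training stops after early_stopping_patience consecutive epochs without one. The sketch below assumes that rule; the EarlyStopping class is hypothetical, is not part of fairness_training, and may differ from what FairTrainer does internally.

```python
class EarlyStopping:
    """Hypothetical helper: tracks validation loss and signals when to stop."""

    def __init__(self, patience: int = 25, delta: float = 1e-05) -> None:
        self.patience = patience
        self.delta = delta
        self.best_loss = float("inf")
        self.epochs_without_improvement = 0

    def step(self, val_loss: float) -> bool:
        """Record one epoch's validation loss; return True if training should stop."""
        if val_loss < self.best_loss - self.delta:
            # Improvement larger than delta: remember it and reset the counter.
            self.best_loss = val_loss
            self.epochs_without_improvement = 0
        else:
            # No (sufficient) improvement: spend one epoch of patience.
            self.epochs_without_improvement += 1
        return self.epochs_without_improvement >= self.patience


# Usage: with patience=2 and delta=0.01, the drops to 0.499 and 0.498 are too
# small to count, so the loop stops after the fourth epoch.
stopper = EarlyStopping(patience=2, delta=0.01)
for epoch, loss in enumerate([1.0, 0.5, 0.499, 0.498], start=1):
    if stopper.step(loss):
        print(f"stopping after epoch {epoch}")  # prints "stopping after epoch 4"
        break
```

A larger delta makes the stopper ignore tiny fluctuations in validation loss; a larger patience tolerates longer plateaus before stopping.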
fit(train_loader, val_loader=None, epochs=100, verbose=1, log_interval=10)
Run the training loop for the model.
Training enforces hard fairness constraints on every batch. Validation uses whatever batch sizes val_loader yields.
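Because the constraint is applied per batch, the loader's batch layout matters: PyTorch's DataLoader yields a smaller final batch whenever the dataset size is not a multiple of batch_size, unless drop_last=True is passed. That last batch may fall below b_tau even when batch_size itself does not. The sketch below reproduces that arithmetic for the default (map-style, no custom batch_sampler) case; the helper names are hypothetical, and how FairTrainer treats a short batch is not stated here.

```python
from typing import List


def loader_batch_sizes(num_samples: int, batch_size: int, drop_last: bool = False) -> List[int]:
    """Batch sizes a DataLoader with the default batching would yield (hypothetical helper)."""
    full, remainder = divmod(num_samples, batch_size)
    sizes = [batch_size] * full
    if remainder and not drop_last:
        # The final partial batch is kept unless drop_last=True.
        sizes.append(remainder)
    return sizes


def batches_below_threshold(sizes: List[int], b_tau: int) -> List[int]:
    """Indices of batches that contain fewer than b_tau samples (hypothetical helper)."""
    return [i for i, size in enumerate(sizes) if size < b_tau]


sizes = loader_batch_sizes(1000, 64)                   # 15 full batches plus one batch of 40
print(batches_below_threshold(sizes, b_tau=50))        # [15]
print(batches_below_threshold(loader_batch_sizes(1000, 64, drop_last=True), b_tau=50))  # []
```

Passing drop_last=True to the training DataLoader is one way to guarantee that every training batch has exactly batch_size samples.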
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| train_loader | DataLoader | Training data loader (its batch_size should be at least b_tau) | required |
| val_loader | Optional[DataLoader] | Validation data loader (optional) | None |
| epochs | int | Maximum number of training epochs | 100 |
| verbose | int | Verbosity level (0 = silent, 1 = progress, 2 = detailed) | 1 |
| log_interval | int | Log progress every log_interval epochs | 10 |
Returns:
| Type | Description |
|---|---|
| Dict | Dictionary containing the training history |
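The loop that fit() describes can be sketched in plain Python as follows. This is a minimal illustration under stated assumptions, not FairTrainer's code: run_training and its callables are hypothetical, the history keys "train_loss" and "val_loss" are assumed names (the real keys are not documented here), and early stopping and scheduler steps are omitted for brevity. It shows two documented behaviors: an inference-state reset hook runs before each validation pass, and progress is printed only every log_interval epochs when verbose >= 1.

```python
from typing import Callable, Dict, List, Optional


def run_training(
    train_epoch: Callable[[], float],
    validate: Optional[Callable[[], float]] = None,
    before_validate: Optional[Callable[[], None]] = None,
    epochs: int = 100,
    verbose: int = 1,
    log_interval: int = 10,
) -> Dict[str, List[float]]:
    """Hypothetical stand-in for a fit()-style loop that returns a history dict."""
    history: Dict[str, List[float]] = {"train_loss": [], "val_loss": []}
    for epoch in range(1, epochs + 1):
        history["train_loss"].append(train_epoch())
        if validate is not None:
            if before_validate is not None:
                # e.g. reset any inference state carried over from training
                before_validate()
            history["val_loss"].append(validate())
        # Print only every log_interval-th epoch, and only when verbose >= 1.
        if verbose >= 1 and epoch % log_interval == 0:
            message = f"epoch {epoch}: train_loss={history['train_loss'][-1]:.4f}"
            if history["val_loss"]:
                message += f", val_loss={history['val_loss'][-1]:.4f}"
            print(message)
    return history


# Usage with fake per-epoch losses: prints on epochs 2 and 4 only.
fake_losses = iter([0.9, 0.7, 0.6, 0.55])
history = run_training(lambda: next(fake_losses), epochs=4, log_interval=2)
```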
evaluate(test_loader, return_predictions=False)
Evaluate the model on a test set.
Resets the inference state before evaluation, then runs inference on every batch, using the primal-dual procedure for small batches where applicable.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| test_loader | DataLoader | Test data loader | required |
| return_predictions | bool | If True, return the predictions array | False |
Returns:
| Type | Description |
|---|---|
| Dict | Dictionary of test metrics, including per-attribute fairness gaps |
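The exact gap metric that evaluate() reports is not defined on this page. Purely as an illustration, the sketch below assumes a demographic-parity-style gap: for each sensitive attribute, the difference between the largest and smallest mean prediction across that attribute's groups. The function per_attribute_gaps and the attribute names are hypothetical.

```python
from collections import defaultdict
from collections.abc import Hashable
from typing import Dict, List, Sequence


def per_attribute_gaps(
    predictions: Sequence[float],
    attributes: Dict[str, Sequence[Hashable]],
) -> Dict[str, float]:
    """Hypothetical metric: for each attribute, max minus min of per-group mean prediction."""
    gaps: Dict[str, float] = {}
    for name, groups in attributes.items():
        if len(groups) != len(predictions):
            raise ValueError(
                f"attribute {name!r} has {len(groups)} entries, expected {len(predictions)}"
            )
        # Collect predictions per group value of this attribute.
        buckets: Dict[Hashable, List[float]] = defaultdict(list)
        for pred, group in zip(predictions, groups):
            buckets[group].append(pred)
        means = [sum(vals) / len(vals) for vals in buckets.values()]
        gaps[name] = max(means) - min(means)
    return gaps


preds = [0.0, 0.5, 0.5, 1.0]
print(per_attribute_gaps(preds, {"sex": ["a", "a", "b", "b"], "age": ["y", "o", "o", "y"]}))
# {'sex': 0.5, 'age': 0.0}
```

A gap of 0.0 means every group of that attribute receives the same mean prediction under this assumed definition.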
save_checkpoint(filepath, include_history=True)
Save a model checkpoint, including the fairness hyperparameters.
load_checkpoint(filepath, load_optimizer=True)
Load a model checkpoint.
If the checkpoint contains a 'model_config' entry, its values are compared with the current model's configuration, and a warning is issued for each mismatch so that configuration drift is caught early instead of passing silently.
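The config-drift check described above can be sketched with the standard warnings module. This is a minimal illustration, assuming 'model_config' is a flat dict of hyperparameters; warn_on_config_mismatch is a hypothetical helper, and the keys and values in the usage lines are made up for the example rather than taken from FairModel.

```python
import warnings
from typing import Any, Dict, List


def warn_on_config_mismatch(saved: Dict[str, Any], current: Dict[str, Any]) -> List[str]:
    """Hypothetical helper: warn once per saved key that is missing or differs in current."""
    mismatched: List[str] = []
    for key, saved_value in saved.items():
        if key not in current:
            detail = "missing from current model"
        elif current[key] != saved_value:
            detail = f"checkpoint={saved_value!r}, current={current[key]!r}"
        else:
            continue  # values agree, nothing to report
        warnings.warn(f"model_config mismatch for {key!r}: {detail}", UserWarning, stacklevel=2)
        mismatched.append(key)
    return mismatched


# Usage with illustrative values: warns about b_tau only.
checkpoint = {"model_config": {"b_tau": 64, "hidden_dim": 16}}
current_config = {"b_tau": 32, "hidden_dim": 16}
if "model_config" in checkpoint:
    warn_on_config_mismatch(checkpoint["model_config"], current_config)
```

Warning rather than raising keeps the checkpoint loadable while still surfacing drift; callers who want a hard failure can escalate with warnings.simplefilter("error").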
See Also
- FairModel - The model class
- Training Guide - Best practices
- Utilities - Data loading helpers