Visualization
fairness_training includes a lightweight plotting module, fairness_training.viz, with three functions for the figures that appear most often in fairness papers. All three require matplotlib, which is an optional dependency.
Setup
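```bash
# Install with visualization support (quote the extras so shells such as zsh
# do not treat the square brackets as a glob pattern)
pip install -e ".[train,verify,viz]"

# Or add matplotlib to an existing environment
pip install matplotlib
```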
```python
from fairness_training.viz import (
    plot_training_history,
    plot_group_distributions,
    plot_fairness_tradeoff,
)
```
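Because matplotlib is optional, a script that should also run in environments without it can check for the package before importing the plotting module. The following is a minimal standard-library sketch; `viz_available` is a hypothetical helper, not part of fairness_training.

```python
import importlib.util


def viz_available() -> bool:
    """Return True if matplotlib can be imported in this environment.

    Hypothetical helper (not part of fairness_training). find_spec locates
    the package without importing it, so the check itself is cheap.
    """
    return importlib.util.find_spec("matplotlib") is not None
```

In a training script, you can wrap the `from fairness_training.viz import ...` line and every plotting call in `if viz_available():`. Training then still completes when matplotlib is not installed.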
plot_training_history
Plot loss and fairness gap curves from the dict returned by trainer.fit().
```python
# Assumes trainer, train_loader and val_loader are already set up
history = trainer.fit(train_loader, val_loader=val_loader, epochs=50)

fig = plot_training_history(
    history,
    title="FairModel — mean_pred metric, ε=0.05",
)
```
What it shows: two side-by-side panels, with train/val loss on the left and train/val fairness gap on the right. The fairness-gap panel shows at a glance whether the model stays within ε throughout training.
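You can also run the fairness-gap check numerically instead of reading it off the panel. The sketch below assumes that each history entry is a list with one float per epoch. This is an assumption; the expected keys are listed after the argument table. It returns the 1-based epochs where the gap exceeded ε. `epochs_outside_tolerance` is a hypothetical helper, not a library function.

```python
from typing import Dict, List


def epochs_outside_tolerance(
    history: Dict[str, List[float]],
    epsilon: float,
    key: str = "train_fairness_gap",
) -> List[int]:
    """Return the 1-based epochs whose fairness gap under `key` exceeds epsilon.

    Hypothetical helper. Assumes history[key] holds one float per epoch.
    A missing optional key (e.g. val_fairness_gap when fit() ran without
    a val_loader) yields an empty list.
    """
    gaps = history.get(key, [])
    return [epoch for epoch, gap in enumerate(gaps, start=1) if gap > epsilon]


# Illustrative, made-up values (not real training output)
history = {
    "train_loss": [0.69, 0.52, 0.47, 0.45],
    "train_fairness_gap": [0.12, 0.06, 0.04, 0.03],
    "val_fairness_gap": [0.11, 0.07, 0.05, 0.06],
}
print(epochs_outside_tolerance(history, epsilon=0.05))                  # [1, 2]
print(epochs_outside_tolerance(history, 0.05, key="val_fairness_gap"))  # [1, 2, 4]
```

A gap exactly equal to ε counts as within tolerance here (epoch 3 on the validation curve). Switch `>` to `>=` if you want a strict bound.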
Arguments:
| Argument | Default | Description |
|---|---|---|
| `history` | required | Dict from `trainer.fit()` |
| `title` | `"Training History"` | Figure title |
| `figsize` | `(12, 4)` | (width, height) in inches |
| `save_path` | `None` | Save to file instead of displaying |
Expected keys in history: train_loss, val_loss (optional), train_fairness_gap, val_fairness_gap (optional).
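When a history dict is loaded from disk or assembled by hand, a quick pre-plot check can catch missing keys. The sketch below is based only on the expected keys documented above. It also assumes that every present key holds one value per epoch, so all series should have the same length. That second check is an assumption about the data, not a documented rule. `check_history` is hypothetical.

```python
from typing import Dict, List, Sequence

REQUIRED_KEYS = ("train_loss", "train_fairness_gap")
OPTIONAL_KEYS = ("val_loss", "val_fairness_gap")


def check_history(history: Dict[str, Sequence[float]]) -> List[str]:
    """Return a list of problems found in a history dict (empty if none).

    Hypothetical helper. Flags missing required keys, and flags series of
    unequal length on the assumption that each key stores one value per epoch.
    """
    problems = [f"missing required key: {k}" for k in REQUIRED_KEYS if k not in history]
    present = [k for k in REQUIRED_KEYS + OPTIONAL_KEYS if k in history]
    lengths = {k: len(history[k]) for k in present}
    if len(set(lengths.values())) > 1:
        problems.append(f"series lengths differ: {lengths}")
    return problems


print(check_history({"train_loss": [0.7, 0.5], "train_fairness_gap": [0.1, 0.04]}))  # []
# Reports a missing train_fairness_gap and mismatched lengths
print(check_history({"train_loss": [0.7, 0.5], "val_loss": [0.8]}))
```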
plot_group_distributions
Histogram of model predictions for each protected group. This is a quick visual check that the fairness layer is working: for each protected attribute, the prediction distributions of the two groups should have similar means.
```python
# Pass return_predictions=True to get predictions and group indicators
metrics = trainer.evaluate(test_loader, return_predictions=True)

fig = plot_group_distributions(
    metrics['predictions'],
    metrics['protected'],
    attr_names={0: 'Gender', 1: 'Race'},  # optional human-readable names
    title="Test-set prediction distributions after fairness training",
)
```
Arguments:
| Argument | Default | Description |
|---|---|---|
| `predictions` | required | Shape `(n,)` or `(n, 1)` array of model outputs |
| `protected_indicators` | required | Dict `{attr_idx: binary 0/1 array}` from `metrics['protected']` |
| `attr_names` | `None` | Dict `{attr_idx: "Name"}` for axis labels |
| `title` | `"Prediction Distributions by Group"` | Figure title |
| `bins` | `30` | Number of histogram bins |
| `figsize` | auto | `(6 * n_attrs, 4)` by default |
| `save_path` | `None` | Save to file instead of displaying |
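The histograms let you compare group means by eye. The standard-library sketch below makes the same comparison numerically. For each protected attribute, it returns the absolute difference between the mean prediction of group 1 and that of group 0. It assumes `predictions` is flat (flatten `(n, 1)` outputs first) and that each indicator is a 0/1 sequence aligned with it, as the table above describes. `group_mean_gaps` is hypothetical, and the library's own mean_pred gap may be defined differently.

```python
from statistics import mean
from typing import Dict, Sequence


def group_mean_gaps(
    predictions: Sequence[float],
    protected_indicators: Dict[int, Sequence[int]],
) -> Dict[int, float]:
    """For each protected attribute, return |mean(group 1) - mean(group 0)|.

    Hypothetical helper for illustration, not the library's metric code.
    statistics.mean raises StatisticsError if either group is empty.
    """
    gaps = {}
    for attr_idx, indicator in protected_indicators.items():
        group1 = [p for p, g in zip(predictions, indicator) if g == 1]
        group0 = [p for p, g in zip(predictions, indicator) if g == 0]
        gaps[attr_idx] = abs(mean(group1) - mean(group0))
    return gaps


# Made-up predictions for six examples and two protected attributes
preds = [0.9, 0.7, 0.6, 0.4, 0.8, 0.2]
protected = {0: [1, 1, 1, 0, 0, 0], 1: [1, 0, 1, 0, 1, 0]}
print({k: round(v, 3) for k, v in group_mean_gaps(preds, protected).items()})
# {0: 0.267, 1: 0.333}
```

A large value for an attribute corresponds to two visibly offset histograms in that attribute's panel.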
plot_fairness_tradeoff
Scatter plot of loss vs fairness gap across multiple model configurations. This is the standard "fairness–accuracy tradeoff" figure from fairness papers.
```python
import torch.nn as nn
import torch.optim as optim

from fairness_training.viz import plot_fairness_tradeoff

# Assumes FairModel, FairTrainer, criterion, INPUT_DIM, train_loader and
# test_loader are already defined in your script.
epsilons = [0.01, 0.05, 0.10, 0.20]
results = {}

for eps in epsilons:
    # Fresh backbone per run so each tolerance trains from scratch
    backbone = nn.Sequential(
        nn.Linear(INPUT_DIM, 64), nn.ReLU(),
        nn.Linear(64, 32), nn.ReLU(),
        nn.Linear(32, 1), nn.Sigmoid(),
    )
    m = FairModel.wrap(backbone, protected_attr_idx=0, fairness_tolerance=eps)
    t = FairTrainer(m, criterion, optim.Adam(m.parameters()))
    t.fit(train_loader, epochs=30, verbose=False)
    results[f"ε={eps}"] = t.evaluate(test_loader)

fig = plot_fairness_tradeoff(
    results,
    loss_key='test_loss',
    gap_key='weighted_avg_fairness_gap',
    title='Fairness–Accuracy Tradeoff (mean_pred)',
)
```
Arguments:
| Argument | Default | Description |
|---|---|---|
| `results` | required | Dict `{label: metrics_dict}` |
| `loss_key` | `"test_loss"` | Key to use as x-axis |
| `gap_key` | `"weighted_avg_fairness_gap"` | Key to use as y-axis |
| `title` | `"Fairness–Accuracy Tradeoff"` | Figure title |
| `figsize` | `(7, 5)` | (width, height) in inches |
| `save_path` | `None` | Save to file instead of displaying |
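According to the argument table above, each configuration is plotted at `(metrics[loss_key], metrics[gap_key])`. A common way to read such a figure is to look for configurations that no other configuration beats on both axes, known as the Pareto front. The standard-library sketch below does both steps. `tradeoff_points` and `pareto_front` are hypothetical helpers, not part of fairness_training.viz, and the metric values are made up.

```python
from typing import Dict, List, Tuple


def tradeoff_points(
    results: Dict[str, Dict[str, float]],
    loss_key: str = "test_loss",
    gap_key: str = "weighted_avg_fairness_gap",
) -> List[Tuple[str, float, float]]:
    """Return (label, x, y) for each configuration, as the scatter places it."""
    return [(label, m[loss_key], m[gap_key]) for label, m in results.items()]


def pareto_front(points: List[Tuple[str, float, float]]) -> List[str]:
    """Return labels of points not dominated by any other point.

    Lower is better on both axes. A point is dominated if another point is
    no worse on both axes and strictly better on at least one.
    """
    front = []
    for label, x, y in points:
        dominated = any(
            (x2 <= x and y2 <= y) and (x2 < x or y2 < y)
            for _, x2, y2 in points
        )
        if not dominated:
            front.append(label)
    return front


# Made-up metrics for four tolerance settings (not real results)
results = {
    "ε=0.01": {"test_loss": 0.52, "weighted_avg_fairness_gap": 0.010},
    "ε=0.05": {"test_loss": 0.47, "weighted_avg_fairness_gap": 0.040},
    "ε=0.1": {"test_loss": 0.48, "weighted_avg_fairness_gap": 0.090},
    "ε=0.2": {"test_loss": 0.44, "weighted_avg_fairness_gap": 0.150},
}
print(pareto_front(tradeoff_points(results)))  # ['ε=0.01', 'ε=0.05', 'ε=0.2']
```

In this made-up example, ε=0.1 is dominated by ε=0.05, which has both lower loss and a lower gap. On the figure it would sit above and to the right of that point.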
Which gap metric to use on the y-axis
Use weighted_avg_fairness_gap for the y-axis. It is the quantity the theorem bounds, and it is what the model is actually optimising against. pooled_fairness_gap can look better or worse depending on the group ratios within each batch, and it does not correspond directly to the training objective.
Saving Figures
All three functions accept a save_path argument:
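```python
fig = plot_training_history(history, save_path="training_history.pdf")
# preds and protected: metrics['predictions'] and metrics['protected']
# from trainer.evaluate(..., return_predictions=True)
fig = plot_group_distributions(preds, protected, save_path="group_dist.png")
fig = plot_fairness_tradeoff(results, save_path="tradeoff.pdf")
```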
Supported formats: anything matplotlib can save, for example .png, .pdf, .svg and .eps.
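A mistyped extension or a missing output directory may only surface once a long run reaches the plotting step. This assumes save_path is passed through to matplotlib's savefig, which does not create missing directories. The hypothetical helper below validates the path up front and creates its parent directory. It checks only the four formats listed above, so extend `KNOWN_FORMATS` if you use others.

```python
from pathlib import Path

# The four formats listed above; matplotlib supports more, extend as needed.
KNOWN_FORMATS = {".png", ".pdf", ".svg", ".eps"}


def check_save_path(save_path: str) -> Path:
    """Validate a figure path before an expensive run and create its parent dir.

    Hypothetical helper, not part of fairness_training.viz.
    Raises ValueError for an extension outside KNOWN_FORMATS.
    """
    path = Path(save_path)
    if path.suffix.lower() not in KNOWN_FORMATS:
        raise ValueError(f"unexpected figure format {path.suffix!r} in {save_path!r}")
    path.parent.mkdir(parents=True, exist_ok=True)
    return path
```

For example, call `path = check_save_path("figures/tradeoff.pdf")` at the top of a script, then pass `save_path=str(path)` to any of the three functions.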
Quickstart Notebook
The Quickstart Colab notebook shows all three functions with live output on a synthetic dataset.