Linear readability vs nonlinear heads — with conformal diagnostics

This notebook is intentionally thin. The experiment lives in linread_heads_core_v4.py; this notebook only uploads and runs that script, then displays its output tables and figures.

# Runtime check: interpreter, library versions, and the accelerator visible to this kernel.
import torch, sklearn, sys, platform

cuda_ok = torch.cuda.is_available()
for label, value in (
    ("Python:", sys.version),
    ("Torch:", torch.__version__),
    ("CUDA available:", cuda_ok),
    ("Device:", torch.cuda.get_device_name(0) if cuda_ok else "cpu"),
    ("sklearn:", sklearn.__version__),
):
    print(label, value)
# Upload the core script if it is not already present in /content.
from pathlib import Path

CORE_NAME = 'linread_heads_core_v4.py'
core = Path('/content') / CORE_NAME
if not core.exists():
    try:
        from google.colab import files
        print(f'Upload {CORE_NAME}')
        uploaded = files.upload()
        # Prefer the exact file name; otherwise accept the first uploaded .py file.
        if CORE_NAME in uploaded:
            chosen = CORE_NAME
        else:
            chosen = next((name for name in uploaded if name.endswith('.py')), None)
        if chosen is not None:
            # files.upload() writes into the current working directory, which need not be
            # /content, so always write the uploaded bytes to the path checked below.
            core.write_bytes(uploaded[chosen])
            if chosen != CORE_NAME:
                print(f'Saved {chosen} as {core}')
    except Exception as e:
        # Best effort only (e.g. not running in Colab): the assertion below reports
        # what is still missing.
        print('Upload helper failed:', e)

assert core.exists(), f'{CORE_NAME} not found. Upload it and rerun this cell.'
print(core, 'ready')
# Experiment configuration, interpolated into the command cells below.
DATASET = 'mnist'
OUT_DIR = '/content/runs/linread_heads_mnist_v4'
SEED = 43
EPOCHS = 20
SAVE_EPOCHS = '0,1,2,5,10,20'   # passed to --save-epochs

# Probe feature budgets. The test side uses the full MNIST test split for cleaner
# conformal calibration/evaluation; 5_000-10_000 train features are enough for
# speed, and larger is more stable.
PROBE_TRAIN_SAMPLES = 10_000
PROBE_TEST_SAMPLES = 10_000

# Conformal settings.
#   CONFORMAL_SCORE    -- 'rank' (gives nonempty sets) or 'lac'.
#   CALIBRATION_SOURCE -- 'holdout' splits the held-out/test pool into calibration
#                         and evaluation parts (see HOLDOUT_CALIB_FRACTION), or 'train'.
CONFORMAL_SCORE = 'rank'
CALIBRATION_SOURCE = 'holdout'
HOLDOUT_CALIB_FRACTION = 0.5
COVERAGE_TOLERANCE = 0.01
# Train backbone, save checkpoints, then run head and conformal analyses.
# If the backbone is already trained into OUT_DIR, set EPOCHS = 0 and add '--posthoc-only'
# to train_args below (or use the recovery cell that follows instead).
#
# The core script runs through subprocess rather than `!{cmd}`: IPython's `!` does not
# fail the cell on a nonzero exit status, so a crashed run would let Run All continue
# into the result cells with stale or missing CSVs.
import shlex
import subprocess
import sys

CORE_SCRIPT = '/content/linread_heads_core_v4.py'
train_args = [
    '--dataset', DATASET,
    '--out-dir', OUT_DIR,
    '--seed', SEED,
    '--epochs', EPOCHS,
    '--save-epochs', SAVE_EPOCHS,
    '--batch-size', '256',
    '--eval-batch-size', '512',
    '--lr', '1e-3',
    '--weight-decay', '0.05',
    '--label-smoothing', '0.05',
    '--img-size', '32',
    '--patch-size', '4',
    '--embed-dim', '64',
    '--num-layers', '4',
    '--num-heads', '4',
    '--probe-train-samples', PROBE_TRAIN_SAMPLES,
    '--probe-test-samples', PROBE_TEST_SAMPLES,
    '--bootstrap-samples', '1000',
    '--rf-trees', '300',
    '--conformal-score', CONFORMAL_SCORE,
    '--conformal-calibration-source', CALIBRATION_SOURCE,
    '--conformal-holdout-calib-fraction', HOLDOUT_CALIB_FRACTION,
    '--coverage-tolerance', COVERAGE_TOLERANCE,
]
# -u: unbuffered child output, so log lines reach the cell as they are produced.
argv = [sys.executable, '-u', CORE_SCRIPT, *(str(a) for a in train_args)]
cmd = shlex.join(argv)  # printable, shell-quoted form of the exact command being run
print(cmd)

# Stream the script's combined stdout/stderr into the cell; Popen's context manager
# waits for the process on exit, so returncode is set afterwards.
with subprocess.Popen(argv, stdout=subprocess.PIPE, stderr=subprocess.STDOUT,
                      text=True, bufsize=1) as proc:
    for line in proc.stdout:
        print(line, end='')
if proc.returncode != 0:
    raise subprocess.CalledProcessError(proc.returncode, argv)

Recovery / posthoc-only mode

If training finished but the analysis crashed, or you changed only the conformal settings, run the next cell instead of retraining. It reuses the checkpoints saved in OUT_DIR/checkpoints.

# Optional: rerun analysis only from existing checkpoints (no retraining).
# Set RUN_POSTHOC_ONLY = True and run this cell when needed. With the default (False)
# the cell does nothing, so Restart & Run All never triggers it by accident.
RUN_POSTHOC_ONLY = False

if RUN_POSTHOC_ONLY:
    import shlex
    import subprocess
    import sys

    posthoc_args = [
        '--dataset', DATASET,
        '--out-dir', OUT_DIR,
        '--seed', SEED,
        '--epochs', '0',
        '--posthoc-only',
        '--save-epochs', SAVE_EPOCHS,
        '--batch-size', '256',
        '--eval-batch-size', '512',
        '--img-size', '32',
        '--patch-size', '4',
        '--embed-dim', '64',
        '--num-layers', '4',
        '--num-heads', '4',
        '--probe-train-samples', PROBE_TRAIN_SAMPLES,
        '--probe-test-samples', PROBE_TEST_SAMPLES,
        '--bootstrap-samples', '1000',
        '--rf-trees', '300',
        '--conformal-score', CONFORMAL_SCORE,
        '--conformal-calibration-source', CALIBRATION_SOURCE,
        '--conformal-holdout-calib-fraction', HOLDOUT_CALIB_FRACTION,
        '--coverage-tolerance', COVERAGE_TOLERANCE,
    ]
    argv = [sys.executable, '-u', '/content/linread_heads_core_v4.py',
            *(str(a) for a in posthoc_args)]
    cmd = shlex.join(argv)
    print(cmd)

    # Stream output into the cell and fail loudly on a nonzero exit (unlike `!cmd`).
    with subprocess.Popen(argv, stdout=subprocess.PIPE, stderr=subprocess.STDOUT,
                          text=True, bufsize=1) as proc:
        for line in proc.stdout:
            print(line, end='')
    if proc.returncode != 0:
        raise subprocess.CalledProcessError(proc.returncode, argv)
# Load outputs
from pathlib import Path
import pandas as pd
from IPython.display import display, Image, Markdown

out = Path(OUT_DIR)
results = pd.read_csv(out / 'head_results.csv')

# The core script writes readability diagnostics to readability_metrics.csv; older
# runs produced geometry_results.csv instead, so fall back to that file when needed.
sep_path = out / 'readability_metrics.csv'
old_geom_path = out / 'geometry_results.csv'
geom_source = next((p for p in (sep_path, old_geom_path) if p.exists()), None)
if geom_source is None:
    geom = pd.DataFrame()
    print('No readability/geometry metrics CSV found.')
else:
    geom = pd.read_csv(geom_source)

# Conformal results are optional; an empty frame makes the conformal cells below no-ops.
conf_path = out / 'conformal_results.csv'
conf = pd.read_csv(conf_path) if conf_path.exists() else pd.DataFrame()
def no_error_mask(df):
    """Return a boolean row mask that is True where no error was recorded.

    Empty strings in a CSV 'error' column typically come back from read_csv as NaN,
    so both NaN and '' count as "no error". A frame with no 'error' column at all
    is treated as entirely error-free.
    """
    if 'error' in df.columns:
        err = df['error']
        return err.isna() | (err.astype(str) == '')
    return pd.Series(True, index=df.index)

print('outputs:', out)
display(results.tail())
if not geom.empty:
    print('readability metrics:')
    display(geom.tail())
if not conf.empty:
    print('conformal results:')
    display(conf.tail())
# Final checkpoint: accuracy deltas with paired uncertainty
final_epoch = int(results['epoch'].max())
final = results.loc[results['epoch'] == final_epoch].sort_values('test_acc', ascending=False)
print('Final checkpoint:', final_epoch)


def _present(frame, wanted):
    """Return the names in `wanted` that are columns of `frame`, in `wanted` order."""
    return [c for c in wanted if c in frame.columns]


# Every head versus the logistic-regression baseline.
cols = [
    'head', 'head_family', 'test_acc',
    'delta_vs_linear', 'delta_ci_low', 'delta_ci_high',
    'relative_gain_of_remaining_error',
    'mcnemar_n01', 'mcnemar_n10', 'mcnemar_p',
]
print('Deltas versus logistic regression baseline:')
display(final[_present(final, cols)])

# Every head versus the best linear head, shown only when the CSV carries those columns.
best_cols = [
    'head', 'head_family', 'test_acc', 'best_linear_head', 'best_linear_acc',
    'delta_vs_best_linear', 'delta_best_linear_ci_low', 'delta_best_linear_ci_high',
    'relative_gain_vs_best_linear_remaining_error',
    'mcnemar_vs_best_linear_n01', 'mcnemar_vs_best_linear_n10', 'mcnemar_vs_best_linear_p',
]
if 'delta_vs_best_linear' in final.columns:
    print('Deltas versus the best linear head at this checkpoint:')
    display(final[_present(final, best_cols)])
# Final checkpoint: conformal diagnostics.
# Important: sort valid rows first. A smaller set size is not meaningful if coverage is below target.
if not conf.empty:
    final_conf = conf[(conf['epoch'] == final_epoch) & no_error_mask(conf)].copy()
    if len(final_conf) == 0:
        print('No successful conformal rows for final checkpoint.')
    else:
        # .eq(True) instead of .astype(bool): astype(bool) maps a missing (NaN) flag to
        # True, which would rank a row of unknown validity among the valid ones.
        final_conf['valid_at_target'] = final_conf['valid_at_target'].eq(True)
        final_conf['valid_sort'] = ~final_conf['valid_at_target']  # False sorts first
        conf_cols = ['head', 'head_family', 'selection_acc', 'test_acc', 'coverage_target',
                     'coverage', 'coverage_gap', 'avg_set_size', 'median_set_size',
                     'singleton_rate', 'empty_rate', 'valid_at_target',
                     'selected_overall', 'selected_nonlinear', 'conformal_score',
                     'calibration_source']
        final_conf = final_conf.sort_values(['valid_sort', 'avg_set_size', 'coverage'],
                                            ascending=[True, True, False])
        display(final_conf[[c for c in conf_cols if c in final_conf.columns]])
# Head selected on the selection split at each checkpoint
if not conf.empty:
    if 'selected_overall' not in conf.columns:
        print('No selected_overall column in conformal results; skipping selected-head table.')
    else:
        selected = (
            conf[conf['selected_overall'].eq(True) & no_error_mask(conf)]
            .sort_values('epoch', kind='stable')
        )
        if len(selected) == 0:
            print('No selected-head conformal rows found. Check the conformal results table above.')
        else:
            # Tolerate missing columns, like the other conformal tables in this notebook.
            selected_cols = ['epoch', 'head', 'selection_acc', 'test_acc', 'coverage',
                             'avg_set_size', 'singleton_rate', 'empty_rate', 'valid_at_target']
            display(selected[[c for c in selected_cols if c in selected.columns]])
# Smallest valid average conformal set size at each checkpoint.
# Only rows that reached target coverage compete; a small set is meaningless otherwise.
if not conf.empty:
    clean = conf[no_error_mask(conf)].copy()
    if len(clean) == 0:
        print('No successful conformal rows found. Empty CSV error fields may have been parsed unexpectedly.')
    else:
        # .eq(True) instead of .astype(bool): a NaN flag must not count as valid.
        clean['valid_at_target'] = clean['valid_at_target'].eq(True)
        # Also drop rows without a set size, so idxmin never sees an all-NaN epoch group.
        valid = clean[clean['valid_at_target'] & clean['avg_set_size'].notna()].copy()
        if len(valid) == 0:
            print('No rows reached target coverage. Inspect coverage plot first.')
        else:
            idx = valid.groupby('epoch')['avg_set_size'].idxmin()
            best_valid = valid.loc[idx].sort_values('epoch')
            best_valid_cols = ['epoch', 'head', 'test_acc', 'coverage', 'avg_set_size',
                               'singleton_rate', 'empty_rate', 'valid_at_target']
            display(best_valid[[c for c in best_valid_cols if c in best_valid.columns]])
# Display figures written by the core script; files that do not exist are skipped.
FIGURE_NAMES = (
    'head_accuracy_by_epoch.png',
    'nonlinear_gain_by_epoch.png',
    'gain_vs_linear_accuracy.png',
    'geometry_diagnostics.png',
    'conformal_avg_set_size_by_epoch.png',
    'conformal_coverage_by_epoch.png',
    'conformal_singleton_rate_by_epoch.png',
    'conformal_selected_head_by_epoch.png',
    'conformal_set_size_vs_top_score_final.png',
)
fig_dir = out / 'figures'
for fig_name in FIGURE_NAMES:
    fig_path = fig_dir / fig_name
    if not fig_path.exists():
        continue
    display(Markdown(f'### {fig_name}'))
    display(Image(filename=str(fig_path)))