# Training Guide
This guide covers how to train antibody non-specificity prediction models using the pipeline.
## Overview
Training involves:
- Configuration - Define datasets, model, classifier, and experiment parameters in YAML
- Embedding Extraction - Generate ESM-1v embeddings for training sequences
- Cross-Validation - Evaluate model performance via 10-fold stratified CV
- Final Training - Train on full training set
- Test Evaluation - Evaluate on hold-out test set
- Model Persistence - Save trained model in dual format (pickle + NPZ+JSON)
## Quick Training Commands
### Train with Default Config
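With no overrides, Hydra falls back to the defaults in config.yaml (see Configuration with Hydra below), so the command should simply be:

```bash
uv run antibody-train
```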
### Train with Hydra Overrides

```bash
# Override specific parameters
uv run antibody-train hardware.device=cuda training.batch_size=32
uv run antibody-train classifier.C=0.5 classifier.penalty=l1 classifier.solver=liblinear
```
## Configuration with Hydra
Training is controlled via Hydra configuration in src/antibody_training_esm/conf/. The default config structure:
```yaml
# src/antibody_training_esm/conf/config.yaml
model:
  name: "facebook/esm1v_t33_650M_UR90S_1"  # ESM-1v model from HuggingFace
  revision: "main"                         # Model revision (for reproducibility)

data:
  train_file: "data/train/boughter/canonical/VH_only_boughter_training.csv"
  sequence_column: "sequence"  # Column containing antibody sequences
  label_column: "label"        # Column containing binary labels (0=specific, 1=non-specific)

classifier:
  type: "logistic_regression"
  C: 1.0            # Regularization strength (inverse)
  penalty: "l2"     # Regularization type (l1, l2, elasticnet, none)
  solver: "lbfgs"   # Optimization algorithm
  max_iter: 1000    # Maximum iterations
  random_state: 42  # Seed for reproducibility
  cv_folds: 10      # Number of cross-validation folds

training:
  save_model: true  # Save trained model to disk
  model_name: "boughter_vh_esm1v_logreg"
  model_save_dir: "./experiments/checkpoints"
  batch_size: 8     # Embedding extraction batch size (default)
  num_workers: 4

experiment:
  name: "boughter_novo_reproduction"

hardware:
  device: "auto"    # Auto-detects CUDA > MPS > CPU
```
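For orientation, here is a minimal sketch of how the classifier block maps onto scikit-learn, assuming the logistic_regression type wraps sklearn.linear_model.LogisticRegression (the parameter names match; the pipeline's actual wiring may differ):

```python
# Illustrative only (assumption, not the pipeline's verbatim code):
# each classifier.* key becomes a LogisticRegression constructor argument.
from sklearn.linear_model import LogisticRegression

clf = LogisticRegression(
    C=1.0,            # classifier.C
    penalty="l2",     # classifier.penalty
    solver="lbfgs",   # classifier.solver
    max_iter=1000,    # classifier.max_iter
    random_state=42,  # classifier.random_state
)
# classifier.cv_folds drives the 10-fold stratified CV step, not the estimator itself.
```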
Note: With Hydra, you can override any parameter from the CLI without editing files. For example (keys taken from the default config above):
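```bash
# Any key from config.yaml can be overridden at launch time
# (the model name here is just an example)
uv run antibody-train classifier.cv_folds=5 training.model_name="cv5_experiment"
```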
## Training on Different Datasets

### Boughter → Jain (Default)
Train on Boughter (914 VH, ELISA), test on Jain (86 clinical, ELISA):
Training config:

```yaml
data:
  train_file: "data/train/boughter/canonical/VH_only_boughter_training.csv"
  sequence_column: "sequence"
  label_column: "label"
```
Training: run `uv run antibody-train` (the default config already points at this training file).
Testing (after training):

```bash
uv run antibody-test \
  --model experiments/checkpoints/esm1v/logreg/boughter_vh_esm1v_logreg.pkl \
  --data data/test/jain/fragments/VH_only_jain.csv
```
Expected Accuracy: 68.60% (exact Novo parity; matches Figure S14A)
### Boughter → Harvey (Nanobodies)
Train on Boughter (914 VH, ELISA), test on Harvey (141k nanobodies, PSR):
Training config:

```yaml
data:
  train_file: "data/train/boughter/canonical/VH_only_boughter_training.csv"
  sequence_column: "sequence"
  label_column: "label"
```
Training: same as for Boughter → Jain above (the training set and config are unchanged).
Testing (after training):

```bash
uv run antibody-test \
  --model experiments/checkpoints/esm1v/logreg/boughter_vh_esm1v_logreg.pkl \
  --data data/test/harvey/fragments/VHH_only_harvey.csv
```
Note: Cross-assay (ELISA → PSR) and cross-species (human antibodies → nanobodies) evaluation may reduce performance.
### Boughter → Shehata (PSR Cross-Validation)
Train on Boughter (914 VH, ELISA), test on Shehata (398, PSR):
Training config:

```yaml
data:
  train_file: "data/train/boughter/canonical/VH_only_boughter_training.csv"
  sequence_column: "sequence"
  label_column: "label"
```
Training: same as for Boughter → Jain above (the training set and config are unchanged).
Testing (after training):

```bash
uv run antibody-test \
  --model experiments/checkpoints/esm1v/logreg/boughter_vh_esm1v_logreg.pkl \
  --data data/test/shehata/fragments/VH_only_shehata.csv
```
Note: Cross-assay prediction (ELISA → PSR) requires assay-specific threshold tuning.
## Training on Different Fragments
The pipeline supports training on various antibody sequence fragments by using different fragment CSV files.
How it works: Fragment files are pre-generated during preprocessing and stored in data/train/{dataset}/annotated/ or data/train/{dataset}/fragments/.
### Example: Train on Heavy CDRs Only

```bash
# Override data file from CLI
uv run antibody-train data.train_file="data/train/boughter/annotated/H-CDRs_boughter.csv"
```
### Available Fragments
Boughter (Training Set):
- VH_only_boughter_training.csv - Variable Heavy chain (default)
- H-CDR1_boughter.csv, H-CDR2_boughter.csv, H-CDR3_boughter.csv - Individual Heavy CDRs
- H-CDRs_boughter.csv - All Heavy CDRs concatenated
- H-FWRs_boughter.csv - Heavy Framework Regions
- (See data/train/boughter/annotated/ for all 16 fragments)
Fragment Naming Pattern:
- Format: {fragmentName}_{dataset}.csv
- Examples: VH_only_boughter.csv, All-CDRs_jain.csv, VHH_only_harvey.csv
Note: Fragment availability depends on dataset. See docs/datasets/{dataset}/ for preprocessing details and data/train/{dataset}/annotated/ for available files.
## Hyperparameter Tuning

### Regularization Strength (C)
Smaller C = stronger regularization (simpler model)
```yaml
classifier:
  C: 0.1   # Strong regularization (underfitting risk)
  # OR
  C: 1.0   # Default (balanced)
  # OR
  C: 10.0  # Weak regularization (overfitting risk)
```
Use cases:
- Small datasets: Use stronger regularization (C=0.1)
- Large datasets: Can use weaker regularization (C=10.0)
### Regularization Type (penalty)

```yaml
classifier:
  penalty: "l2"          # Ridge (default, works well for most cases)
  # OR
  penalty: "l1"          # Lasso (feature selection, requires solver="liblinear")
  # OR
  penalty: "elasticnet"  # Elastic Net (L1 + L2, requires solver="saga")
  # OR
  penalty: "none"        # No regularization (overfitting risk)
```
Note: Penalty type must match solver:
- l2: use solver: "lbfgs" (default)
- l1: use solver: "liblinear"
- elasticnet: use solver: "saga"
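For example, a valid L1 setup using the same config keys as above (illustrative only):

```yaml
classifier:
  penalty: "l1"
  solver: "liblinear"  # lbfgs does not support l1
```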
### Hyperparameter Sweep

For systematic hyperparameter search, see the example sweep script in preprocessing/boughter/train_hyperparameter_sweep.py; a minimal CLI-only sketch follows the strategy list below.
Sweep strategy:
- Define parameter grid (e.g., C=[0.01, 0.1, 1.0, 10.0])
- Train model for each configuration
- Embeddings are cached (fast re-runs)
- Compare cross-validation metrics
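A minimal CLI-only sketch of such a sweep using Hydra overrides (illustrative; the experiment.name override is assumed to work like any other config key):

```bash
# Sweep regularization strength C; cached embeddings make runs after the first fast
for C in 0.01 0.1 1.0 10.0; do
  uv run antibody-train classifier.C="$C" experiment.name="sweep_C_${C}"
done
```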
## Understanding Training Output

### Cross-Validation Metrics
```
✅ 10-Fold Cross-Validation:
  - Accuracy:  71.2% ± 3.5%
  - Precision: 68.3% ± 4.2%
  - Recall:    72.1% ± 5.1%
  - F1 Score:  70.1% ± 3.8%
  - ROC-AUC:   0.75 ± 0.04
```
Interpretation:
- Accuracy: Overall correct predictions (71.2% on Boughter)
- Precision: Of predicted non-specific, how many are truly non-specific
- Recall: Of truly non-specific, how many were detected
- F1 Score: Harmonic mean of precision and recall
- ROC-AUC: Area under ROC curve (0.5 = random, 1.0 = perfect)
Standard Deviation (±): Variability across folds (lower = more stable)
### Test Set Metrics

```
Confusion Matrix:
                 Predicted
                 Neg   Pos
Actual   Neg   [ 40    17 ]   ← True Neg: 40,  False Pos: 17
         Pos   [ 10    19 ]   ← False Neg: 10, True Pos: 19
```
Performance:
- The matrix above corresponds to (40 + 19) / 86 = 68.6% test accuracy
- A drop from CV accuracy (~71%) to test accuracy (~69%) is expected
- Cross-dataset generalization is challenging (different assays, antibody sources)
- Novo Nordisk reported 68.6%; we achieve 68.60% (exact parity)
## Training Best Practices

### 1. Start with Default Config
Use the validated default configuration as a baseline:
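```bash
# Defaults come from src/antibody_training_esm/conf/config.yaml
uv run antibody-train
```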
This ensures reproducibility and provides a reference point for comparisons.
### 2. Use Appropriate Test Sets
- Same-assay testing: Boughter → Jain (both ELISA)
- Cross-assay testing: Boughter → Harvey/Shehata (ELISA → PSR)
- Match fragment types: Train on VH, test on VH (not VL)
### 3. Monitor Overfitting
If cross-validation accuracy is high but test accuracy is low:
- Increase regularization: decrease C (e.g., 1.0 → 0.1)
- Use L1 penalty: feature selection via Lasso
- Simplify model: consider a simpler classifier
### 4. Leverage Embedding Caching

Embeddings are cached automatically:

```
experiments/cache/
└── {dataset}_{SHA256_hash}_embeddings.pkl  # Pickle dict with embeddings + metadata
```
Benefits:
- Hyperparameter sweeps run 10-100x faster
- Cache invalidates automatically when data/model changes
- No manual cache management required
Note: First run downloads ESM-1v (~700 MB) and extracts embeddings (~5-10 min). Subsequent runs are instant.
### 5. Save All Experiments

Always enable model saving:

```yaml
training:
  save_model: true
  model_name: "descriptive_experiment_name"  # Use meaningful names
  model_save_dir: "experiments/checkpoints/"

experiment:
  name: "descriptive_experiment_name"
```
Dual-format model saving (automatic):

```
experiments/checkpoints/
└── {model_name}/
    └── {classifier}/
        ├── {model_name}.pkl          # Pickle (research/debugging)
        ├── {model_name}.npz          # NumPy arrays (production weights)
        └── {model_name}_config.json  # Metadata (production config)
```
Why dual-format?
- Pickle (.pkl): Fast iteration, debugging, hyperparameter sweeps
- NPZ+JSON (.npz + _config.json): Production deployment, cross-platform, secure (no code execution)
## Model Loading

### Loading Models for Testing

Option 1: Pickle (research/debugging)

```bash
uv run antibody-test \
  --model experiments/checkpoints/esm1v/logreg/boughter_vh_esm1v_logreg.pkl \
  --data data/test/jain/fragments/VH_only_jain.csv
```
Option 2: NPZ+JSON (production deployment)
```python
from antibody_training_esm.core import load_model_from_npz

# Load model from production format
model = load_model_from_npz(
    npz_path="experiments/checkpoints/esm1v/logreg/boughter_vh_esm1v_logreg.npz",
    json_path="experiments/checkpoints/esm1v/logreg/boughter_vh_esm1v_logreg_config.json",
)

# Use model for predictions
predictions = model.predict(X_test_embeddings)
```

When to use each format:
- Pickle: research workflows, local experiments, fast iteration
- NPZ+JSON: production APIs, HuggingFace deployment, cross-language loading
## Advanced Training

### Custom Dataset Paths

```yaml
data:
  train_file: "/absolute/path/to/training_data.csv"
  sequence_column: "sequence"  # Column name for sequences
  label_column: "label"        # Column name for binary labels
```
CSV Format Requirements:
- Must have a sequence column (antibody amino acid sequence)
- Must have a label column (0=specific, 1=non-specific)
- Column names can be customized via the sequence_column and label_column config keys
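For illustration, a minimal CSV in the expected shape (the sequences below are truncated placeholders, not real antibodies):

```csv
sequence,label
EVQLVESGGGLVQPGGSLRLSCAAS...,0
QVQLQQSGAELVKPGASVKMSCKAS...,1
```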
### Custom ESM Model
Use different ESM model versions via Hydra config groups:
```bash
# ESM-1v (default, Novo Nordisk validated)
uv run antibody-train model=esm1v

# ESM2-650M (supported, comparable performance)
uv run antibody-train model=esm2_650m
```
Or override the model name directly:
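```bash
# Assumes model.name can be overridden like any other config key
# (full model names are listed below)
uv run antibody-train model.name="facebook/esm2_t33_650M_UR50D"
```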
Available ESM models:
| Model | Config Group | Performance | Status |
|---|---|---|---|
| ESM-1v (650M) | model=esm1v | 66.3% Jain accuracy (Novo baseline) | ✅ Default |
| ESM2-650M | model=esm2_650m | ~64-68% Jain (predicted) | ✅ Supported |
| ESM2-3B | N/A (manual override) | Higher (requires 24+ GB GPU) | 📋 Planned |
Full model names (for manual override):
- facebook/esm1v_t33_650M_UR90S_1 - ESM-1v (default, validated)
- facebook/esm2_t33_650M_UR50D - ESM2-650M (supported)
- facebook/esm2_t36_3B_UR50D - ESM2-3B (requires 24+ GB GPU)
### GPU Memory Management
Reduce batch size if encountering OOM errors:
```yaml
training:
  batch_size: 8   # Default; lower if needed (e.g., 4)

hardware:
  device: "cuda"  # or "mps" for Apple Silicon
```
Memory Requirements:
| Batch Size | GPU Memory | Speed |
|---|---|---|
| 4 | 4 GB | Slow |
| 8 | 8 GB | Medium (default) |
| 16 | 12 GB | Fast |
| 32 | 24 GB | Fastest |
## Troubleshooting

### Training Fails with "Label column not found"
Solution: Ensure CSV has label column with 0/1 values:
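```bash
# Inspect the CSV header to confirm a label column with 0/1 values exists
head -n 3 data/train/boughter/canonical/VH_only_boughter_training.csv

# If your labels live under a different column name, point the config at it
# ("my_label_column" is a placeholder for whatever your CSV actually uses)
uv run antibody-train data.label_column=my_label_column
```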
### Embeddings Cache Out of Sync
Solution: Clear cache and retrain:
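```bash
# Remove cached embeddings (stored in experiments/cache/); they are regenerated on the next run
rm -rf experiments/cache/
uv run antibody-train
```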
### Poor Test Performance
Possible causes:
- Cross-assay mismatch: Train ELISA, test PSR → tune threshold
- Cross-species mismatch: Train human, test nanobodies → expect lower accuracy
- Overfitting: High CV accuracy, low test accuracy → increase regularization
- Underfitting: Low CV and test accuracy → decrease regularization
See Troubleshooting Guide for more solutions.
## Next Steps
- Testing Models: See Testing Guide for evaluating trained models
- Preprocessing Data: See Preprocessing Guide for preparing new datasets
- Hyperparameter Sweeps: See the reference script in preprocessing/boughter/train_hyperparameter_sweep.py
Last Updated: 2025-11-18
Branch: dev