
Training Guide

This guide covers how to train antibody non-specificity prediction models using the pipeline.


Overview

Training involves:

  1. Configuration - Define datasets, model, classifier, and experiment parameters in YAML
  2. Embedding Extraction - Generate ESM-1v embeddings for training sequences
  3. Cross-Validation - Evaluate model performance via 10-fold stratified CV
  4. Final Training - Train on full training set
  5. Test Evaluation - Evaluate on hold-out test set
  6. Model Persistence - Save trained model in dual format (pickle + NPZ+JSON)
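Steps 3 to 5 can be sketched with scikit-learn as follows. This is a minimal illustration under stated assumptions, not the pipeline's own code: the random arrays stand in for ESM-1v embeddings, their size (64 dimensions) is chosen only to keep the example fast, and the variable names are hypothetical.

```python
import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import StratifiedKFold, cross_val_score

# Stand-in data: random vectors instead of real ESM-1v embeddings (assumption).
rng = np.random.default_rng(42)
X_train = rng.normal(size=(200, 64))
y_train = rng.integers(0, 2, size=200)
X_test = rng.normal(size=(50, 64))
y_test = rng.integers(0, 2, size=50)

# Mirrors the default classifier settings; L2 is scikit-learn's default penalty.
clf = LogisticRegression(C=1.0, solver="lbfgs", max_iter=1000, random_state=42)

# Step 3: 10-fold stratified cross-validation on the training embeddings.
cv = StratifiedKFold(n_splits=10, shuffle=True, random_state=42)
cv_scores = cross_val_score(clf, X_train, y_train, cv=cv, scoring="accuracy")
print(f"CV accuracy: {cv_scores.mean():.3f} ± {cv_scores.std():.3f}")

# Step 4: refit on the full training set.
clf.fit(X_train, y_train)

# Step 5: evaluate once on the hold-out test set.
test_accuracy = clf.score(X_test, y_test)
print(f"Test accuracy: {test_accuracy:.3f}")
```

With random labels the scores hover around chance; the point is the order of operations, where the test set is touched only after cross-validation and the final fit.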

Quick Training Commands

Train with Default Config

make train
# OR
uv run antibody-train

Train with Hydra Overrides

# Override specific parameters
uv run antibody-train hardware.device=cuda training.batch_size=32
uv run antibody-train classifier.C=0.5 classifier.penalty=l1 classifier.solver=liblinear

Configuration with Hydra

Training is controlled via Hydra configuration in src/antibody_training_esm/conf/. The default config structure:

# src/antibody_training_esm/conf/config.yaml

model:
  name: "facebook/esm1v_t33_650M_UR90S_1"  # ESM-1v model from HuggingFace
  revision: "main"                         # Model revision (for reproducibility)

data:
  train_file: "data/train/boughter/canonical/VH_only_boughter_training.csv"
  sequence_column: "sequence"  # Column containing antibody sequences
  label_column: "label"        # Column containing binary labels (0=specific, 1=non-specific)

classifier:
  type: "logistic_regression"
  C: 1.0                      # Regularization strength (inverse)
  penalty: "l2"               # Regularization type (l1, l2, elasticnet, none)
  solver: "lbfgs"             # Optimization algorithm
  max_iter: 1000              # Maximum iterations
  random_state: 42            # Seed for reproducibility
  cv_folds: 10                # Number of cross-validation folds

training:
  save_model: true            # Save trained model to disk
  model_name: "boughter_vh_esm1v_logreg"
  model_save_dir: "./experiments/checkpoints"
  batch_size: 8               # Embedding extraction batch size (default)
  num_workers: 4

experiment:
  name: "boughter_novo_reproduction"

hardware:
  device: "auto"              # Auto-detects CUDA > MPS > CPU

Note: With Hydra, you can override any parameter from the CLI without editing files:

uv run antibody-train hardware.device=cuda classifier.C=0.5
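To make the override semantics concrete, here is a small pure-Python sketch of how dotted `key=value` arguments map onto a nested config. It is an illustration only and not Hydra's implementation; `parse_value` and `apply_overrides` are hypothetical helpers, and Hydra's real override grammar is richer than this.

```python
import copy
from typing import Any


def parse_value(raw: str) -> Any:
    """Best-effort conversion of a CLI string to bool, int, float, or str."""
    if raw.lower() in ("true", "false"):
        return raw.lower() == "true"
    for cast in (int, float):
        try:
            return cast(raw)
        except ValueError:
            continue
    return raw


def apply_overrides(config: dict, overrides: list[str]) -> dict:
    """Return a copy of `config` with `group.key=value` overrides applied."""
    result = copy.deepcopy(config)
    for item in overrides:
        dotted_key, raw_value = item.split("=", 1)
        *parents, leaf = dotted_key.split(".")
        node = result
        for name in parents:
            node = node[name]  # KeyError if the config group does not exist
        if leaf not in node:
            raise KeyError(f"Unknown config key: {dotted_key}")
        node[leaf] = parse_value(raw_value)
    return result


defaults = {
    "hardware": {"device": "auto"},
    "classifier": {"C": 1.0, "penalty": "l2"},
}
merged = apply_overrides(defaults, ["hardware.device=cuda", "classifier.C=0.5"])
print(merged)
```

The defaults are left untouched and only the named leaves change, which is why an override on the command line never requires editing the YAML file.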

Training on Different Datasets

Boughter → Jain (Default)

Train on Boughter (914 VH, ELISA), test on Jain (86 clinical, ELISA):

Training config:

data:
  train_file: "data/train/boughter/canonical/VH_only_boughter_training.csv"
  sequence_column: "sequence"
  label_column: "label"

Training:

uv run antibody-train

Testing (after training):

uv run antibody-test \
  --model experiments/checkpoints/esm1v/logreg/boughter_vh_esm1v_logreg.pkl \
  --data data/test/jain/fragments/VH_only_jain.csv

Expected Accuracy: 68.60% (exact parity with Novo Nordisk; matches Figure S14A)


Boughter → Harvey (Nanobodies)

Train on Boughter (914 VH, ELISA), test on Harvey (141k nanobodies, PSR):

Training config:

data:
  train_file: "data/train/boughter/canonical/VH_only_boughter_training.csv"
  sequence_column: "sequence"
  label_column: "label"

Training:

uv run antibody-train

Testing (after training):

uv run antibody-test \
  --model experiments/checkpoints/esm1v/logreg/boughter_vh_esm1v_logreg.pkl \
  --data data/test/harvey/fragments/VHH_only_harvey.csv

Note: Cross-assay (ELISA → PSR) and cross-species (human antibodies → nanobodies) may reduce performance.


Boughter → Shehata (PSR Cross-Assay Validation)

Train on Boughter (914 VH, ELISA), test on Shehata (398, PSR):

Training config:

data:
  train_file: "data/train/boughter/canonical/VH_only_boughter_training.csv"
  sequence_column: "sequence"
  label_column: "label"

Training:

uv run antibody-train

Testing (after training):

uv run antibody-test \
  --model experiments/checkpoints/esm1v/logreg/boughter_vh_esm1v_logreg.pkl \
  --data data/test/shehata/fragments/VH_only_shehata.csv

Note: Cross-assay prediction (ELISA → PSR) requires assay-specific threshold tuning.
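One way to tune such a threshold is to scan candidate cut-offs on a small labelled calibration set from the target assay and keep the one with the best accuracy. The sketch below assumes you already have predicted probabilities for that set; the scores, labels, and helper names are made up for illustration, and the pipeline's own tuning procedure may differ.

```python
def accuracy_at(threshold, probabilities, labels):
    """Accuracy when predicting non-specific (1) for probability >= threshold."""
    predictions = [1 if p >= threshold else 0 for p in probabilities]
    correct = sum(int(pred == y) for pred, y in zip(predictions, labels))
    return correct / len(labels)


def tune_threshold(probabilities, labels, candidates=None):
    """Return the candidate threshold with the highest accuracy (first on ties)."""
    if candidates is None:
        candidates = [i / 100 for i in range(1, 100)]
    return max(candidates, key=lambda t: accuracy_at(t, probabilities, labels))


# Made-up calibration data: scores are systematically shifted below 0.5.
probs = [0.10, 0.20, 0.30, 0.35, 0.40, 0.45, 0.60, 0.80]
labels = [0, 0, 0, 1, 1, 1, 1, 1]

best = tune_threshold(probs, labels)
print(f"default 0.50 -> accuracy {accuracy_at(0.5, probs, labels):.3f}")
print(f"tuned   {best:.2f} -> accuracy {accuracy_at(best, probs, labels):.3f}")
```

Tune on calibration data that is separate from the final test set; choosing the threshold on the test set itself would inflate the reported accuracy.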


Training on Different Fragments

The pipeline supports training on various antibody sequence fragments by using different fragment CSV files.

How it works: Fragment files are pre-generated during preprocessing and stored in data/train/{dataset}/annotated/ or data/train/{dataset}/fragments/.

Example: Train on Heavy CDRs Only

# Override data file from CLI
uv run antibody-train data.train_file="data/train/boughter/annotated/H-CDRs_boughter.csv"

Available Fragments

Boughter (Training Set):

  • VH_only_boughter_training.csv - Variable Heavy chain (default)
  • H-CDR1_boughter.csv, H-CDR2_boughter.csv, H-CDR3_boughter.csv - Individual Heavy CDRs
  • H-CDRs_boughter.csv - All Heavy CDRs concatenated
  • H-FWRs_boughter.csv - Heavy Framework Regions
  • (See data/train/boughter/annotated/ for all 16 fragments)

Fragment Naming Pattern:

  • Format: {fragmentName}_{dataset}.csv
  • Examples: VH_only_boughter.csv, All-CDRs_jain.csv, VHH_only_harvey.csv

Note: Fragment availability depends on dataset. See docs/datasets/{dataset}/ for preprocessing details and data/train/{dataset}/annotated/ for available files.


Hyperparameter Tuning

Regularization Strength (C)

Smaller C = stronger regularization (simpler model)

classifier:
  C: 1.0     # Default (balanced)
  # Alternatives (set exactly one value for C):
  # C: 0.1   # Strong regularization (underfitting risk)
  # C: 10.0  # Weak regularization (overfitting risk)

Use cases:

  • Small datasets: Use stronger regularization (C=0.1)
  • Large datasets: Can use weaker regularization (C=10.0)

Regularization Type (penalty)

classifier:
  penalty: "l2"            # Ridge (default, works well for most cases)
  # Alternatives (set exactly one penalty value):
  # penalty: "l1"          # Lasso (feature selection, requires solver="liblinear" or "saga")
  # penalty: "elasticnet"  # Elastic Net (L1 + L2, requires solver="saga")
  # penalty: "none"        # No regularization (overfitting risk)

Note: Penalty type must match solver:

  • l2: Use solver: "lbfgs" (default)
  • l1: Use solver: "liblinear" (or "saga")
  • elasticnet: Use solver: "saga"
  • none: Use solver: "lbfgs" (default)
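A quick pre-flight check can catch a mismatched penalty and solver before any embeddings are extracted. The sketch below encodes the combinations scikit-learn's `LogisticRegression` accepts for three common solvers; the `check_penalty_solver` helper and the table name are hypothetical and not part of the pipeline.

```python
# Penalties accepted by each solver in scikit-learn's LogisticRegression.
# "none" here is the config string used in this guide for "no penalty".
SUPPORTED_PENALTIES = {
    "lbfgs": {"l2", "none"},
    "liblinear": {"l1", "l2"},
    "saga": {"l1", "l2", "elasticnet", "none"},
}


def check_penalty_solver(penalty: str, solver: str) -> None:
    """Raise ValueError if the penalty/solver pair is not supported."""
    if solver not in SUPPORTED_PENALTIES:
        raise ValueError(f"Solver {solver!r} is not covered by this check")
    if penalty not in SUPPORTED_PENALTIES[solver]:
        allowed = sorted(SUPPORTED_PENALTIES[solver])
        raise ValueError(
            f"penalty={penalty!r} is incompatible with solver={solver!r}; "
            f"allowed penalties: {allowed}"
        )


check_penalty_solver("l2", "lbfgs")        # default config: passes silently
check_penalty_solver("l1", "liblinear")    # Lasso with a matching solver

try:
    check_penalty_solver("l1", "lbfgs")    # common mistake: l1 with the default solver
except ValueError as err:
    print(err)
```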

Hyperparameter Sweep

For systematic hyperparameter search, see example sweep script:

# See reference implementation
cat preprocessing/boughter/train_hyperparameter_sweep.py

Sweep strategy:

  1. Define parameter grid (e.g., C=[0.01, 0.1, 1.0, 10.0])
  2. Train model for each configuration
  3. Embeddings are cached (fast re-runs)
  4. Compare cross-validation metrics
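The sweep strategy above can be sketched as a loop over `C` that reuses one set of embeddings. This is a hedged illustration and not the contents of the reference script: the synthetic `embeddings` and `labels` arrays stand in for cached ESM-1v embeddings, and all names are assumptions.

```python
import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import StratifiedKFold, cross_val_score

# Stand-in for cached embeddings: computed once, reused for every configuration.
rng = np.random.default_rng(0)
embeddings = rng.normal(size=(120, 32))
labels = (embeddings[:, 0] + 0.5 * rng.normal(size=120) > 0).astype(int)

# Same folds for every C so the comparison is like-for-like.
cv = StratifiedKFold(n_splits=10, shuffle=True, random_state=42)

results = {}
for C in [0.01, 0.1, 1.0, 10.0]:
    clf = LogisticRegression(C=C, max_iter=1000, random_state=42)
    scores = cross_val_score(clf, embeddings, labels, cv=cv, scoring="accuracy")
    results[C] = (scores.mean(), scores.std())
    print(f"C={C:<5} accuracy {scores.mean():.3f} ± {scores.std():.3f}")

# Pick the configuration with the highest mean CV accuracy.
best_C = max(results, key=lambda c: results[c][0])
print(f"Best C: {best_C}")
```

Select hyperparameters on cross-validation metrics only, and keep the hold-out test set for a single final evaluation.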

Understanding Training Output

Cross-Validation Metrics

✅ 10-Fold Cross-Validation:
   - Accuracy: 71.2% ± 3.5%
   - Precision: 68.3% ± 4.2%
   - Recall: 72.1% ± 5.1%
   - F1 Score: 70.1% ± 3.8%
   - ROC-AUC: 0.75 ± 0.04

Interpretation:

  • Accuracy: Overall correct predictions (71.2% on Boughter)
  • Precision: Of predicted non-specific, how many are truly non-specific
  • Recall: Of truly non-specific, how many were detected
  • F1 Score: Harmonic mean of precision and recall
  • ROC-AUC: Area under ROC curve (0.5 = random, 1.0 = perfect)

Standard Deviation (±): Variability across folds (lower = more stable)


Test Set Metrics

✅ Test Set (Jain):
   - Accuracy: 68.60%
   - Confusion Matrix: [[40, 17], [10, 19]]

Confusion Matrix:

                 Predicted
                 Neg    Pos
Actual  Neg     [40     17]   ← True Neg: 40, False Pos: 17
        Pos     [10     19]   ← False Neg: 10, True Pos: 19
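The reported accuracy follows directly from this matrix, and the same four counts give precision, recall, and F1. A short worked example in plain Python (variable names are illustrative):

```python
# Jain test-set confusion matrix from above: rows = actual, columns = predicted.
confusion = [[40, 17], [10, 19]]
(tn, fp), (fn, tp) = confusion

total = tn + fp + fn + tp                      # 86 antibodies
accuracy = (tp + tn) / total                   # 59 / 86
precision = tp / (tp + fp)                     # 19 / 36
recall = tp / (tp + fn)                        # 19 / 29
f1 = 2 * precision * recall / (precision + recall)

print(f"Accuracy:  {accuracy:.2%}")   # 68.60%
print(f"Precision: {precision:.2%}")  # 52.78%
print(f"Recall:    {recall:.2%}")     # 65.52%
print(f"F1 Score:  {f1:.2%}")         # 58.46%
```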

Performance:

  • CV accuracy (71.2%) exceeds test accuracy (68.60%) by about 3 points, which is expected when the test set comes from a different dataset
  • Cross-dataset generalization is challenging (different assays, antibody sources)
  • Novo Nordisk reported 68.6% - we achieve 68.60% (EXACT PARITY)

Training Best Practices

1. Start with Default Config

Use the validated default configuration as a baseline:

make train

This ensures reproducibility and provides a reference point for comparisons.


2. Use Appropriate Test Sets

  • Same-assay testing: Boughter → Jain (both ELISA)
  • Cross-assay testing: Boughter → Harvey/Shehata (ELISA → PSR)
  • Match fragment types: Train on VH, test on VH (not VL)

3. Monitor Overfitting

If cross-validation accuracy is high but test accuracy is low:

  • Increase regularization: Decrease C (e.g., 1.0 → 0.1)
  • Use L1 penalty: Feature selection via Lasso
  • Simplify model: Consider simpler classifier

4. Leverage Embedding Caching

Embeddings are cached automatically:

experiments/cache/
└── {dataset}_{SHA256_hash}_embeddings.pkl  # Pickle dict with embeddings + metadata

Benefits:

  • Hyperparameter sweeps run 10-100x faster
  • Cache invalidates automatically when data/model changes
  • No manual cache management required

Note: First run downloads ESM-1v (~700 MB) and extracts embeddings (~5-10 min). Subsequent runs load the cached embeddings and skip extraction.
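Automatic invalidation of this kind is typically achieved by deriving the cache key from everything that affects the embeddings. The sketch below shows one plausible scheme using SHA-256 over the model name, revision, and sequences; the `embedding_cache_key` helper and the exact inputs hashed are assumptions, not the pipeline's actual implementation.

```python
import hashlib


def embedding_cache_key(sequences, model_name, revision="main"):
    """Hash the model identity and every sequence into one hex digest."""
    hasher = hashlib.sha256()
    hasher.update(model_name.encode("utf-8"))
    hasher.update(b"\x00")  # separator so adjacent fields cannot run together
    hasher.update(revision.encode("utf-8"))
    for sequence in sequences:
        hasher.update(b"\x00")
        hasher.update(sequence.encode("utf-8"))
    return hasher.hexdigest()


sequences = ["EVQLVESGGGLV", "QVQLQESGPGLV"]  # shortened, illustrative fragments
key = embedding_cache_key(sequences, "facebook/esm1v_t33_650M_UR90S_1")
print(f"boughter_{key}_embeddings.pkl")

# Changing the model (or any sequence) yields a different key, so a stale
# cache file is simply never looked up again.
other_key = embedding_cache_key(sequences, "facebook/esm2_t33_650M_UR50D")
print(key != other_key)  # True
```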


5. Save All Experiments

Always enable model saving:

training:
  save_model: true
  model_name: "descriptive_experiment_name"  # Use meaningful names
  model_save_dir: "experiments/checkpoints/"

experiment:
  name: "descriptive_experiment_name"

Dual-format model saving (automatic):

experiments/checkpoints/
└── {model}/                          # e.g., esm1v
    └── {classifier}/                 # e.g., logreg
        ├── {model_name}.pkl          # Pickle (research/debugging)
        ├── {model_name}.npz          # NumPy arrays (production weights)
        └── {model_name}_config.json  # Metadata (production config)

Why dual-format?

  • Pickle (.pkl): Fast iteration, debugging, hyperparameter sweeps
  • NPZ+JSON (.npz + _config.json): Production deployment, cross-platform, secure (no code execution)
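To illustrate why the NPZ+JSON pair avoids code execution, here is a minimal round trip for a logistic-regression classifier: weights go into a NumPy archive loaded with `allow_pickle=False`, and metadata goes into plain JSON. The array keys (`coef`, `intercept`), the config fields, and the helper names are hypothetical; the pipeline's actual file layout may differ.

```python
import json
import tempfile
from pathlib import Path

import numpy as np


def save_npz_json(coef, intercept, config, npz_path, json_path):
    """Write weights as plain arrays and metadata as JSON."""
    np.savez(npz_path, coef=coef, intercept=intercept)
    Path(json_path).write_text(json.dumps(config, indent=2))


def load_npz_json(npz_path, json_path):
    """Load weights without unpickling anything, plus the JSON metadata."""
    with np.load(npz_path, allow_pickle=False) as arrays:
        coef = arrays["coef"]
        intercept = arrays["intercept"]
    config = json.loads(Path(json_path).read_text())
    return coef, intercept, config


def predict(coef, intercept, X, threshold=0.5):
    """Logistic-regression prediction from raw weights."""
    logits = X @ coef.T + intercept
    probabilities = 1.0 / (1.0 + np.exp(-logits))
    return (probabilities[:, 0] >= threshold).astype(int)


with tempfile.TemporaryDirectory() as tmp:
    npz_path = str(Path(tmp) / "demo.npz")
    json_path = str(Path(tmp) / "demo_config.json")
    save_npz_json(
        coef=np.array([[1.0, -1.0]]),
        intercept=np.array([0.0]),
        config={"C": 1.0, "penalty": "l2"},
        npz_path=npz_path,
        json_path=json_path,
    )
    coef, intercept, config = load_npz_json(npz_path, json_path)

preds = predict(coef, intercept, np.array([[2.0, 0.0], [0.0, 2.0]]))
print(preds.tolist())  # [1, 0]
```

Because both files hold only numbers and text, they can be read from other languages and loaded from untrusted storage without the risks that come with unpickling.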


Model Loading

Loading Models for Testing

Option 1: Pickle (research/debugging)

uv run antibody-test \
  --model experiments/checkpoints/esm1v/logreg/boughter_vh_esm1v_logreg.pkl \
  --data data/test/jain/fragments/VH_only_jain.csv

Option 2: NPZ+JSON (production deployment)

from antibody_training_esm.core import load_model_from_npz

# Load model from production format
model = load_model_from_npz(
    npz_path="experiments/checkpoints/esm1v/logreg/boughter_vh_esm1v_logreg.npz",
    json_path="experiments/checkpoints/esm1v/logreg/boughter_vh_esm1v_logreg_config.json"
)

# Use model for predictions
predictions = model.predict(X_test_embeddings)

When to use each format:

  • Pickle: Research workflows, local experiments, fast iteration
  • NPZ+JSON: Production APIs, HuggingFace deployment, cross-language loading


Advanced Training

Custom Dataset Paths

data:
  train_file: "/absolute/path/to/training_data.csv"
  sequence_column: "sequence"  # Column name for sequences
  label_column: "label"        # Column name for binary labels

CSV Format Requirements:

  • Must have sequence column (antibody amino acid sequence)
  • Must have label column (0=specific, 1=non-specific)
  • Column names can be customized via sequence_column and label_column config keys
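These requirements can be checked before a training run with a few lines of standard-library Python. The `validate_training_csv` helper below is a hypothetical sketch, not part of the pipeline, and the sample sequences are shortened fragments for illustration.

```python
import csv
import io


def validate_training_csv(handle, sequence_column="sequence", label_column="label"):
    """Return the number of data rows, or raise ValueError on the first problem."""
    reader = csv.DictReader(handle)
    header = reader.fieldnames or []
    missing = [name for name in (sequence_column, label_column) if name not in header]
    if missing:
        raise ValueError(f"Missing column(s): {missing}; found {header}")

    row_count = 0
    for line_number, row in enumerate(reader, start=2):  # line 1 is the header
        sequence = (row[sequence_column] or "").strip()
        label = (row[label_column] or "").strip()
        if not sequence:
            raise ValueError(f"Line {line_number}: empty sequence")
        if label not in ("0", "1"):
            raise ValueError(f"Line {line_number}: label must be 0 or 1, got {label!r}")
        row_count += 1
    return row_count


sample = io.StringIO("sequence,label\nEVQLVESGGGLV,0\nQVQLQESGPGLV,1\n")
row_count = validate_training_csv(sample)
print(row_count)  # 2
```

For a real file, pass an open handle, for example `open(path, newline="")`, and the column names configured in `data.sequence_column` and `data.label_column`.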

Custom ESM Model

Use different ESM model versions via Hydra config groups:

# ESM-1v (default, Novo Nordisk validated)
uv run antibody-train model=esm1v

# ESM2-650M (supported, comparable performance)
uv run antibody-train model=esm2_650m

Or override the model name directly:

uv run antibody-train model.name=facebook/esm2_t33_650M_UR50D

Available ESM models:

| Model | Config Group | Performance | Status |
|---|---|---|---|
| ESM-1v (650M) | model=esm1v | 68.60% Jain accuracy (Novo baseline) | ✅ Default |
| ESM2-650M | model=esm2_650m | ~64-68% Jain (predicted) | ✅ Supported |
| ESM2-3B | N/A (manual override) | Higher (requires 24+ GB GPU) | 📋 Planned |

Full model names (for manual override):

  • facebook/esm1v_t33_650M_UR90S_1 - ESM-1v (default, validated)
  • facebook/esm2_t33_650M_UR50D - ESM2-650M (supported)
  • facebook/esm2_t36_3B_UR50D - ESM2-3B (requires 24+ GB GPU)


GPU Memory Management

Reduce batch size if encountering OOM errors:

training:
  batch_size: 8   # Default; lower if needed (e.g., 4)
hardware:
  device: "cuda"  # or "mps" for Apple Silicon

Memory Requirements:

| Batch Size | GPU Memory | Speed |
|---|---|---|
| 4 | 4 GB | Slow |
| 8 | 8 GB | Medium (default) |
| 16 | 12 GB | Fast |
| 32 | 24 GB | Fastest |

Troubleshooting

Training Fails with "Label column not found"

Solution: Ensure the CSV has a label column with 0/1 values, or set data.label_column to the name your file actually uses:

sequence,label
EVQLVESGGGLV...,0
QVQLQESGPGLV...,1

Embeddings Cache Out of Sync

Solution: Clear cache and retrain:

rm -rf experiments/cache/
uv run antibody-train

Poor Test Performance

Possible causes:

  1. Cross-assay mismatch: Train ELISA, test PSR → tune threshold
  2. Cross-species mismatch: Train human, test nanobodies → expect lower accuracy
  3. Overfitting: High CV accuracy, low test accuracy → increase regularization
  4. Underfitting: Low CV and test accuracy → decrease regularization

See Troubleshooting Guide for more solutions.


Next Steps

  • Testing Models: See Testing Guide for evaluating trained models
  • Preprocessing Data: See Preprocessing Guide for preparing new datasets
  • Hyperparameter Sweeps: See reference script in preprocessing/boughter/train_hyperparameter_sweep.py

Last Updated: 2025-11-18 Branch: dev