Trainer
Training Module
Training pipeline for antibody classification models, including cross-validation, embedding caching, and comprehensive evaluation.
Functions
get_or_create_embeddings(sequences, embedding_extractor, cache_path, dataset_name, logger)
Load embeddings from the cache, or compute them if no cached copy exists.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `sequences` | `list[str]` | List of protein sequences | required |
| `embedding_extractor` | `EmbeddingExtractorProtocol` | ESM or AMPLIFY embedding extractor | required |
| `cache_path` | `str \| Path` | Directory for caching embeddings | required |
| `dataset_name` | `str` | Name of dataset (for cache filename) | required |
| `logger` | `Logger` | Logger instance | required |
Returns:
| Type | Description |
|---|---|
| `ndarray` | Array of embeddings |
Raises:
| Type | Description |
|---|---|
| `ValueError` | If cached or computed embeddings are invalid |
Source code in src/antibody_training_esm/core/training/cache.py
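The cache-or-compute behavior can be sketched as follows. This is a minimal, hypothetical illustration rather than the module's implementation: the cache filename scheme (`{dataset_name}_{hash}.npy`), the content hash over the sequences, and the `compute_fn` callable standing in for the embedding extractor are all assumptions.

```python
from __future__ import annotations

import hashlib
import logging
from collections.abc import Callable
from pathlib import Path

import numpy as np


def get_or_create_embeddings_sketch(
    sequences: list[str],
    compute_fn: Callable[[list[str]], np.ndarray],  # hypothetical stand-in for the extractor
    cache_path: str | Path,
    dataset_name: str,
    logger: logging.Logger,
) -> np.ndarray:
    """Return cached embeddings if present; otherwise compute, check, and cache them."""
    cache_dir = Path(cache_path)
    cache_dir.mkdir(parents=True, exist_ok=True)
    # Assumption: key the cache on the dataset name plus a hash of the sequences,
    # so a changed sequence list never reuses a stale cache file.
    digest = hashlib.sha256("\n".join(sequences).encode()).hexdigest()[:12]
    cache_file = cache_dir / f"{dataset_name}_{digest}.npy"

    from_cache = cache_file.exists()
    if from_cache:
        logger.info("Loading cached embeddings from %s", cache_file)
        embeddings = np.load(cache_file, allow_pickle=False)
    else:
        logger.info("Computing embeddings for %d sequences", len(sequences))
        embeddings = np.asarray(compute_fn(sequences))

    # Check before writing so an invalid result is never cached.
    if embeddings.ndim != 2 or embeddings.shape[0] != len(sequences):
        source = "cache" if from_cache else "computation"
        raise ValueError(f"Invalid embeddings from {source}: shape {embeddings.shape}")

    if not from_cache:
        np.save(cache_file, embeddings)
    return embeddings
```

Calling it twice with the same sequences and directory computes once and then reads the `.npy` file; `allow_pickle=False` keeps cache loading free of arbitrary object deserialization.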
validate_embeddings(embeddings, num_sequences, logger, source='cache')
Validate that embeddings are not corrupted.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `embeddings` | `ndarray` | Embedding array to validate | required |
| `num_sequences` | `int` | Expected number of sequences | required |
| `logger` | `Logger` | Logger instance | required |
| `source` | `str` | Where embeddings came from (for error messages) | `'cache'` |
Raises:
| Type | Description |
|---|---|
| `ValueError` | If embeddings are invalid (wrong shape, NaN, all zeros) |
Source code in src/antibody_training_esm/core/training/cache.py
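The three documented failure modes (wrong shape, NaN, all zeros) can be checked roughly as below. This is a sketch under assumptions: it expects a 2-D `(num_sequences, embedding_dim)` array and rejects only an array that is entirely zero; whether the real function also rejects individual all-zero rows is not documented here.

```python
import logging

import numpy as np


def validate_embeddings_sketch(
    embeddings: np.ndarray,
    num_sequences: int,
    logger: logging.Logger,
    source: str = "cache",
) -> None:
    """Raise ValueError if embeddings have the wrong shape, contain NaN, or are all zeros."""
    # Assumption: one row per sequence, one column per embedding dimension.
    if embeddings.ndim != 2 or embeddings.shape[0] != num_sequences:
        raise ValueError(
            f"Embeddings from {source} have shape {embeddings.shape}; "
            f"expected ({num_sequences}, embedding_dim)"
        )
    if np.isnan(embeddings).any():
        raise ValueError(f"Embeddings from {source} contain NaN values")
    # np.any is False only when every element is zero.
    if not np.any(embeddings):
        raise ValueError(f"Embeddings from {source} are all zeros")
    logger.debug("Embeddings from %s passed validation: shape %s", source, embeddings.shape)
```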
evaluate_model(classifier, X, y, dataset_name, _metrics, logger)
Evaluate model performance on a labeled dataset.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `classifier` | `BinaryClassifier` | Trained classifier | required |
| `X` | `ndarray` | Embeddings array | required |
| `y` | `ndarray` | Labels array | required |
| `dataset_name` | `str` | Name of dataset being evaluated | required |
| `_metrics` | `Sequence[str] \| set[str]` | List or set of metrics to compute (ignored; all standard metrics are always computed) | required |
| `logger` | `Logger` | Logger instance | required |
Returns:
| Type | Description |
|---|---|
| `EvaluationMetrics` | `EvaluationMetrics` Pydantic model |
Source code in src/antibody_training_esm/core/training/metrics.py
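Which "standard metrics" `EvaluationMetrics` holds is not listed on this page. As a hedged illustration only, the sketch below computes four common binary-classification metrics from true and predicted 0/1 labels; the dataclass and its fields are hypothetical and are not the module's Pydantic model.

```python
from dataclasses import dataclass

import numpy as np


@dataclass(frozen=True)
class BinaryMetricsSketch:
    # Hypothetical field set; the real EvaluationMetrics fields may differ.
    accuracy: float
    precision: float
    recall: float
    f1: float


def binary_metrics_sketch(y_true: np.ndarray, y_pred: np.ndarray) -> BinaryMetricsSketch:
    """Compute accuracy, precision, recall, and F1 for 0/1 labels."""
    y_true = np.asarray(y_true).astype(bool)
    y_pred = np.asarray(y_pred).astype(bool)
    tp = int(np.sum(y_true & y_pred))
    fp = int(np.sum(~y_true & y_pred))
    fn = int(np.sum(y_true & ~y_pred))
    accuracy = float(np.mean(y_true == y_pred))
    # Guard against division by zero when there are no predicted or actual positives.
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    return BinaryMetricsSketch(accuracy, precision, recall, f1)
```

With a trained classifier, the predictions would come from something like a scikit-learn-style `predict(X)` call; that `BinaryClassifier` exposes such a method is an assumption here.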
perform_cross_validation(X, y, config, logger)
Perform cross-validation on the embeddings and labels.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `X` | `ndarray` | Embeddings array | required |
| `y` | `ndarray` | Labels array | required |
| `config` | `TrainingPipelineConfig \| dict[str, Any]` | Configuration (Pydantic object or legacy dict) | required |
| `logger` | `Logger` | Logger instance | required |
Returns:
| Type | Description |
|---|---|
| `CVResults` | `CVResults` Pydantic model |
Source code in src/antibody_training_esm/core/training/metrics.py
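The number of folds and the splitting strategy are read from `config` and are not documented on this page. The sketch below assumes stratified k-fold with a fixed seed, which may differ from what the module actually does; it only produces train/validation index pairs, and a classifier would then be fit on each training split and scored on the matching validation split.

```python
from collections.abc import Iterator

import numpy as np


def stratified_kfold_indices(
    y: np.ndarray, n_splits: int = 5, seed: int = 42
) -> Iterator[tuple[np.ndarray, np.ndarray]]:
    """Yield (train_idx, val_idx) pairs with class proportions preserved per fold."""
    if n_splits < 2:
        raise ValueError("n_splits must be at least 2")
    y = np.asarray(y)
    rng = np.random.default_rng(seed)
    fold_of = np.empty(len(y), dtype=int)
    for label in np.unique(y):
        idx = np.flatnonzero(y == label)
        rng.shuffle(idx)
        # Deal this class's samples round-robin across folds.
        fold_of[idx] = np.arange(len(idx)) % n_splits
    for k in range(n_splits):
        yield np.flatnonzero(fold_of != k), np.flatnonzero(fold_of == k)
```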
save_cv_results(cv_results, output_dir, experiment_name, logger)
Save cross-validation results to structured YAML file.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `cv_results` | `CVResults` | `CVResults` Pydantic model | required |
| `output_dir` | `Path` | Directory to save CV results file | required |
| `experiment_name` | `str` | Name of the experiment | required |
| `logger` | `Logger` | Logger instance | required |
Source code in src/antibody_training_esm/core/training/metrics.py
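The layout of the YAML file and the fields of `CVResults` are not documented here. As a dependency-free sketch, the function below summarizes per-fold scores (mean, sample standard deviation, raw values) and writes them as a small YAML mapping by hand; the input shape, the filename `{experiment_name}_cv_results.yaml`, and the keys are all hypothetical. A real implementation would more typically use a YAML library such as PyYAML's `yaml.safe_dump`.

```python
import json
import statistics
from pathlib import Path


def save_cv_summary_sketch(
    fold_scores: dict[str, list[float]],
    output_dir: Path,
    experiment_name: str,
) -> Path:
    """Write per-metric mean, sample std, and fold scores to a small YAML file."""
    output_dir.mkdir(parents=True, exist_ok=True)
    out_file = output_dir / f"{experiment_name}_cv_results.yaml"  # hypothetical filename
    # json.dumps yields a double-quoted string, which is also a valid YAML scalar.
    lines = [f"experiment: {json.dumps(experiment_name)}"]
    for metric, scores in sorted(fold_scores.items()):
        mean = statistics.mean(scores)
        std = statistics.stdev(scores) if len(scores) > 1 else 0.0
        folds = ", ".join(f"{s:.4f}" for s in scores)
        lines += [f"{metric}:", f"  mean: {mean:.4f}", f"  std: {std:.4f}", f"  folds: [{folds}]"]
    out_file.write_text("\n".join(lines) + "\n")
    return out_file
```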
load_config(config_path)
Load configuration from a YAML file.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `config_path` | `str` | Path to YAML configuration file | required |
Returns:
| Type | Description |
|---|---|
| `dict[str, Any]` | Configuration dictionary |
Raises:
| Type | Description |
|---|---|
| `FileNotFoundError` | If the config file does not exist |
| `ValueError` | If the YAML is invalid |
Source code in src/antibody_training_esm/core/training/serialization.py
load_model_from_npz(npz_path, json_path)
Load a model from NPZ+JSON format (for production deployment).
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `npz_path` | `str` | Path to .npz file with arrays | required |
| `json_path` | `str` | Path to .json file with metadata | required |
Returns:
| Type | Description |
|---|---|
| `BinaryClassifier` | Reconstructed `BinaryClassifier` instance |
Notes
This function enables production deployment without pickle files. It reconstructs a fully functional BinaryClassifier from the NPZ+JSON files and validates the metadata with strict Pydantic validation.
Source code in src/antibody_training_esm/core/training/serialization.py
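The pickle-free idea behind the NPZ+JSON format can be illustrated with a round trip for a linear model, assuming (hypothetically) that the classifier's state reduces to a coefficient matrix and an intercept, as for the `logreg` classifier type used in the examples on this page. The array names `coef`/`intercept`, the `model_type` metadata key, and the plain-dict check standing in for Pydantic validation are all assumptions, not the module's actual file layout.

```python
import json
from pathlib import Path

import numpy as np


def save_linear_model_npz_sketch(
    coef: np.ndarray, intercept: np.ndarray, metadata: dict, npz_path: Path, json_path: Path
) -> None:
    """Store numeric arrays in .npz and metadata in .json; no pickle involved."""
    np.savez(npz_path, coef=np.asarray(coef), intercept=np.asarray(intercept))
    json_path.write_text(json.dumps(metadata, indent=2))


def load_linear_model_npz_sketch(npz_path: Path, json_path: Path):
    """Load arrays and metadata; allow_pickle=False refuses any pickled objects."""
    with np.load(npz_path, allow_pickle=False) as arrays:
        coef = arrays["coef"]
        intercept = arrays["intercept"]
    metadata = json.loads(json_path.read_text())
    # Hypothetical minimal check; the documented function uses strict Pydantic validation.
    if metadata.get("model_type") != "logistic_regression":
        raise ValueError(f"Unsupported model_type: {metadata.get('model_type')!r}")
    return coef, intercept, metadata


def predict_proba_sketch(X: np.ndarray, coef: np.ndarray, intercept: np.ndarray) -> np.ndarray:
    """Positive-class probability for a logistic-regression-style linear model."""
    logits = (np.asarray(X) @ coef.T + intercept).ravel()
    return 1.0 / (1.0 + np.exp(-logits))
```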
save_model(classifier, config, logger)
Save a trained model in two formats: pickle and NPZ+JSON.
Models are saved in a hierarchical directory structure:
{model_save_dir}/{model_shortname}/{classifier_type}/{model_name}.*
Example:
experiments/checkpoints/esm1v/logreg/boughter_vh_esm1v_logreg.pkl
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `classifier` | `BinaryClassifier` | Trained classifier | required |
| `config` | `TrainingPipelineConfig \| dict[str, Any]` | Configuration dictionary or Pydantic model | required |
| `logger` | `Logger` | Logger instance | required |
Returns:
| Type | Description |
|---|---|
| `dict[str, str]` | Dictionary with paths to saved files, e.g. `{"pickle": "experiments/checkpoints/esm1v/logreg/model.pkl", "npz": "experiments/checkpoints/esm1v/logreg/model.npz", "config": "experiments/checkpoints/esm1v/logreg/model_config.json"}`. Empty dict if saving is disabled. |
Source code in src/antibody_training_esm/core/training/serialization.py
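The documented directory layout and the keys of the returned dictionary can be reproduced with `pathlib`. This sketch only builds the paths; it does not create directories or write files, and how `model_save_dir`, `model_shortname`, `classifier_type`, and `model_name` are read from `config` is not documented here, so they are passed in directly as hypothetical arguments.

```python
from __future__ import annotations

from pathlib import Path


def model_save_paths_sketch(
    model_save_dir: str | Path,
    model_shortname: str,
    classifier_type: str,
    model_name: str,
) -> dict[str, str]:
    """Build {model_save_dir}/{model_shortname}/{classifier_type}/{model_name}.* paths."""
    base = Path(model_save_dir) / model_shortname / classifier_type
    return {
        "pickle": str(base / f"{model_name}.pkl"),
        "npz": str(base / f"{model_name}.npz"),
        "config": str(base / f"{model_name}_config.json"),
    }
```

For example, `model_save_paths_sketch("experiments/checkpoints", "esm1v", "logreg", "boughter_vh_esm1v_logreg")["pickle"]` gives the example path shown above (with the platform's path separator).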
validate_config(config)
Validate a config with Pydantic models.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `config` | `dict[str, Any] \| DictConfig` | Raw dict or Hydra `DictConfig` | required |
Returns:
| Type | Description |
|---|---|
| `TrainingPipelineConfig` | Validated `TrainingPipelineConfig` |
Raises:
| Type | Description |
|---|---|
| `ValidationError` | If config is invalid |
Source code in src/antibody_training_esm/core/trainer.py
setup_logging(config)
Set up logging from a validated Pydantic config.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `config` | `TrainingPipelineConfig` | Validated `TrainingPipelineConfig` | required |
Returns:
| Type | Description |
|---|---|
| `Logger` | Configured logger |
Source code in src/antibody_training_esm/core/trainer.py
train_pipeline(cfg)
Core training pipeline with Pydantic validation.
Source code in src/antibody_training_esm/core/trainer.py
main(cfg)
Hydra entry point for the CLI. Do not call it directly in tests.
This is the CLI entry point decorated with @hydra.main. It:
- Automatically parses command-line overrides
- Creates Hydra output directories
- Saves the composed config to .hydra/config.yaml
- Delegates to train_pipeline() for the core logic
Usage:
- Default config: `python -m antibody_training_esm.core.trainer`
- With overrides: `python -m antibody_training_esm.core.trainer model.batch_size=16`
- Multi-run sweep: `python -m antibody_training_esm.core.trainer --multirun model=esm1v,esm2`
Note
Tests should call train_pipeline() directly, not this function. This function is only for CLI usage with sys.argv parsing.