# Trainer

## trainer

Training Module

Professional training pipeline for antibody classification models. Includes cross-validation, embedding caching, and comprehensive evaluation.

## Functions

### get_or_create_embeddings(sequences, embedding_extractor, cache_path, dataset_name, logger)

Get embeddings from the cache, or compute and cache them.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `sequences` | `list[str]` | List of protein sequences | *required* |
| `embedding_extractor` | `EmbeddingExtractorProtocol` | ESM or AMPLIFY embedding extractor | *required* |
| `cache_path` | `str \| Path` | Directory for caching embeddings | *required* |
| `dataset_name` | `str` | Name of dataset (for cache filename) | *required* |
| `logger` | `Logger` | Logger instance | *required* |

Returns:

| Type | Description |
| --- | --- |
| `ndarray` | Array of embeddings |

Raises:

| Type | Description |
| --- | --- |
| `ValueError` | If cached or computed embeddings are invalid |

Source code in src/antibody_training_esm/core/training/cache.py
def get_or_create_embeddings(
    sequences: list[str],
    embedding_extractor: "EmbeddingExtractorProtocol",
    cache_path: str | Path,
    dataset_name: str,
    logger: logging.Logger,
) -> np.ndarray:
    """
    Get embeddings from cache or create them

    Args:
        sequences: List of protein sequences
        embedding_extractor: ESM or AMPLIFY embedding extractor
        cache_path: Directory for caching embeddings
        dataset_name: Name of dataset (for cache filename)
        logger: Logger instance

    Returns:
        Array of embeddings

    Raises:
        ValueError: If cached or computed embeddings are invalid
    """
    # Ensure cache_path is string for os.path.join/os.makedirs compatibility
    # (os.path supports Path in 3.6+, but for safety/consistency with type hint)
    cache_path_str = str(cache_path)

    # P2.3 fix: Use streaming hash to avoid creating giant string in memory
    # For large datasets (100k+ sequences), joining all sequences into one
    # string can consume 10s of MB and cause memory pressure.
    hasher = hashlib.sha256()

    # Hash model metadata first
    hasher.update(
        f"{embedding_extractor.model_name}|"
        f"{embedding_extractor.revision}|"
        f"{embedding_extractor.max_length}|".encode()
    )

    # Stream sequences through hash (no giant string!)
    for seq in sequences:
        hasher.update(seq.encode())
        hasher.update(b"|")  # Separator

    sequences_hash = hasher.hexdigest()[:12]
    cache_file = os.path.join(
        cache_path_str, f"{dataset_name}_{sequences_hash}_embeddings.pkl"
    )

    if os.path.exists(cache_file):
        logger.info(f"Loading cached embeddings from {cache_file}")
        with open(cache_file, "rb") as f:
            cached_data_raw = pickle.load(f)  # nosec B301 - Hash-validated local cache

        # Validate loaded data type and structure
        if not isinstance(cached_data_raw, dict):
            logger.warning(
                f"Invalid cache file format (expected dict, got {type(cached_data_raw).__name__}). "
                "Recomputing embeddings..."
            )
        elif (
            "embeddings" not in cached_data_raw
            or "sequences_hash" not in cached_data_raw
        ):
            missing_keys = {"embeddings", "sequences_hash"} - set(
                cached_data_raw.keys()
            )
            logger.warning(
                f"Corrupt cache file (missing keys: {missing_keys}). "
                "Recomputing embeddings..."
            )
        else:
            cached_data: dict[str, Any] = cached_data_raw

            # Verify the cached sequences and model metadata match exactly
            # This prevents ESM2 from reusing ESM-1v embeddings, etc.
            model_metadata_matches = (
                cached_data.get("model_name") == embedding_extractor.model_name
                and cached_data.get("revision") == embedding_extractor.revision
                and cached_data.get("max_length") == embedding_extractor.max_length
            )

            if (
                len(cached_data["embeddings"]) == len(sequences)
                and cached_data["sequences_hash"] == sequences_hash
                and model_metadata_matches
            ):
                logger.info(
                    f"Using cached embeddings for {len(sequences)} sequences "
                    f"(model: {embedding_extractor.model_name}, hash: {sequences_hash})"
                )
                embeddings_result: np.ndarray = cached_data["embeddings"]

                # Validate cached embeddings before using them
                validate_embeddings(
                    embeddings_result, len(sequences), logger, source="cache"
                )

                return embeddings_result
            elif not model_metadata_matches:
                logger.warning(
                    f"Cached embeddings model mismatch "
                    f"(cached: {cached_data.get('model_name')}, "
                    f"current: {embedding_extractor.model_name}). "
                    "Recomputing..."
                )
            else:
                logger.warning("Cached embeddings hash mismatch, recomputing...")

    logger.info(f"Computing embeddings for {len(sequences)} sequences...")
    embeddings = embedding_extractor.extract_batch_embeddings(sequences)

    # Validate newly computed embeddings before caching
    validate_embeddings(embeddings, len(sequences), logger, source="computed")

    # Cache the embeddings with metadata for verification
    # Include model metadata to prevent cache collisions between different backbones
    os.makedirs(cache_path_str, exist_ok=True)
    cache_data = {
        "embeddings": embeddings,
        "sequences_hash": sequences_hash,
        "num_sequences": len(sequences),
        "dataset_name": dataset_name,
        "model_name": embedding_extractor.model_name,
        "revision": embedding_extractor.revision,
        "max_length": embedding_extractor.max_length,
    }
    with open(cache_file, "wb") as f:
        pickle.dump(cache_data, f)
    logger.info(
        f"Cached embeddings to {cache_file} "
        f"(model: {embedding_extractor.model_name}, hash: {sequences_hash})"
    )

    return embeddings
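
A minimal usage sketch follows. The `DummyExtractor` below is hypothetical: it only mimics the attributes and method this function relies on (`model_name`, `revision`, `max_length`, `extract_batch_embeddings`), so the cache round-trip can be demonstrated without loading a real ESM/AMPLIFY backbone.

```python
import logging

import numpy as np

from antibody_training_esm.core.training.cache import get_or_create_embeddings


class DummyExtractor:
    """Hypothetical stand-in satisfying the implied extractor interface."""

    model_name = "dummy-backbone"
    revision = "main"
    max_length = 1024

    def extract_batch_embeddings(self, sequences: list[str]) -> np.ndarray:
        # Stand-in for a real forward pass through the language model
        return np.random.rand(len(sequences), 1280).astype(np.float32)


logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("demo")

seqs = ["EVQLVESGGGLVQPGG", "QVQLQQSGAELARPGA"]
emb = get_or_create_embeddings(seqs, DummyExtractor(), "/tmp/emb_cache", "demo", logger)

# A second call with identical sequences and model metadata hits the cache file
again = get_or_create_embeddings(seqs, DummyExtractor(), "/tmp/emb_cache", "demo", logger)
assert np.allclose(emb, again)
```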

### validate_embeddings(embeddings, num_sequences, logger, source='cache')

Validate that embeddings are not corrupted.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `embeddings` | `ndarray` | Embedding array to validate | *required* |
| `num_sequences` | `int` | Expected number of sequences | *required* |
| `logger` | `Logger` | Logger instance | *required* |
| `source` | `str` | Where the embeddings came from (for error messages) | `'cache'` |

Raises:

| Type | Description |
| --- | --- |
| `ValueError` | If embeddings are invalid (wrong shape, NaN, all zeros) |

Source code in src/antibody_training_esm/core/training/cache.py
def validate_embeddings(
    embeddings: np.ndarray,
    num_sequences: int,
    logger: logging.Logger,
    source: str = "cache",
) -> None:
    """
    Validate embeddings are not corrupted.

    Args:
        embeddings: Embedding array to validate
        num_sequences: Expected number of sequences
        logger: Logger instance
        source: Where embeddings came from (for error messages)

    Raises:
        ValueError: If embeddings are invalid (wrong shape, NaN, all zeros)
    """
    # Check shape
    if embeddings.shape[0] != num_sequences:
        raise ValueError(
            f"Embeddings from {source} have wrong shape: expected {num_sequences} sequences, "
            f"got {embeddings.shape[0]}"
        )

    if len(embeddings.shape) != 2:
        raise ValueError(
            f"Embeddings from {source} must be 2D array, got shape {embeddings.shape}"
        )

    # Check for NaN values
    if np.isnan(embeddings).any():
        nan_count = np.isnan(embeddings).sum()
        raise ValueError(
            f"Embeddings from {source} contain {nan_count} NaN values. "
            "This indicates corrupted embeddings - cannot train on invalid data."
        )

    # Check for all-zero rows (corrupted/failed embeddings)
    zero_rows = np.all(embeddings == 0, axis=1)
    if zero_rows.any():
        zero_count = zero_rows.sum()
        raise ValueError(
            f"Embeddings from {source} contain {zero_count} all-zero rows. "
            "This indicates corrupted embeddings from failed batch processing. "
            "Delete the cache file and recompute."
        )

    logger.debug(
        f"Embeddings validation passed: shape={embeddings.shape}, no NaN, no zero rows"
    )
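
A short sketch of the pass/fail behaviour on synthetic arrays:

```python
import logging

import numpy as np

from antibody_training_esm.core.training.cache import validate_embeddings

logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger("demo")

good = np.random.rand(4, 1280)
validate_embeddings(good, num_sequences=4, logger=logger)  # passes, logs at DEBUG

bad = good.copy()
bad[2] = 0.0  # simulate a failed batch -> one all-zero row
try:
    validate_embeddings(bad, num_sequences=4, logger=logger, source="computed")
except ValueError as err:
    print(err)  # "...contain 1 all-zero rows..."
```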

### evaluate_model(classifier, X, y, dataset_name, _metrics, logger)

Evaluate model performance.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `classifier` | `BinaryClassifier` | Trained classifier | *required* |
| `X` | `ndarray` | Embeddings array | *required* |
| `y` | `ndarray` | Labels array | *required* |
| `dataset_name` | `str` | Name of dataset being evaluated | *required* |
| `_metrics` | `Sequence[str] \| set[str]` | List/set of metrics to compute (ignored; all standard metrics are computed) | *required* |
| `logger` | `Logger` | Logger instance | *required* |

Returns:

| Type | Description |
| --- | --- |
| `EvaluationMetrics` | `EvaluationMetrics` Pydantic model |

Source code in src/antibody_training_esm/core/training/metrics.py
def evaluate_model(
    classifier: BinaryClassifier,
    X: np.ndarray,
    y: np.ndarray,
    dataset_name: str,
    _metrics: Sequence[str] | set[str],
    logger: logging.Logger,
) -> EvaluationMetrics:
    """
    Evaluate model performance

    Args:
        classifier: Trained classifier
        X: Embeddings array
        y: Labels array
        dataset_name: Name of dataset being evaluated
        _metrics: List/Set of metrics to compute (ignored, computes all standard metrics)
        logger: Logger instance

    Returns:
        EvaluationMetrics Pydantic model
    """
    logger.info(f"Evaluating model on {dataset_name} set")

    # Get predictions
    y_pred = classifier.predict(X)
    y_pred_proba = classifier.predict_proba(X)

    # Create metrics object using Pydantic factory
    eval_metrics = EvaluationMetrics.from_sklearn_metrics(
        y_true=y,
        y_pred=y_pred,
        y_proba=y_pred_proba,
        dataset_name=dataset_name,
    )

    # Log results
    logger.info(f"{dataset_name} Results:")
    logger.info(f"  Accuracy:  {eval_metrics.accuracy:.4f}")
    if eval_metrics.precision is not None:
        logger.info(f"  Precision: {eval_metrics.precision:.4f}")
    if eval_metrics.recall is not None:
        logger.info(f"  Recall:    {eval_metrics.recall:.4f}")
    if eval_metrics.f1 is not None:
        logger.info(f"  F1:        {eval_metrics.f1:.4f}")
    if eval_metrics.roc_auc is not None:
        logger.info(f"  ROC-AUC:   {eval_metrics.roc_auc:.4f}")

    # Log classification report (useful for detailed class-wise metrics)
    logger.info(f"\n{dataset_name} Classification Report:")
    logger.info(f"\n{classification_report(y, y_pred)}")

    return eval_metrics
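
A hedged end-to-end sketch on synthetic embeddings. The `BinaryClassifier` import path is an assumption, and the parameter keys mirror those assembled in `train_pipeline()` below; note that constructing the classifier may initialise the embedding backbone.

```python
import logging

import numpy as np

from antibody_training_esm.core.training.metrics import evaluate_model

# Import path assumed for illustration; adjust to where BinaryClassifier lives.
from antibody_training_esm.core.classifier import BinaryClassifier

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("demo")

X = np.random.rand(64, 1280)  # stand-in embedding matrix
y = np.random.randint(0, 2, size=64)  # binary labels

clf = BinaryClassifier(
    {
        "model_name": "facebook/esm1v_t33_650M_UR90S_1",  # illustrative value
        "device": "cpu",
        "batch_size": 8,
        "model_type": "esm",
        "strategy": "logistic_regression",
        "C": 1.0,
        "max_iter": 1000,
        "random_state": 42,
    }
)
clf.fit(X, y)

metrics = evaluate_model(clf, X, y, "Training", ["accuracy"], logger)
print(f"accuracy={metrics.accuracy:.4f}")
```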

### perform_cross_validation(X, y, config, logger)

Perform cross-validation.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `X` | `ndarray` | Embeddings array | *required* |
| `y` | `ndarray` | Labels array | *required* |
| `config` | `TrainingPipelineConfig \| dict[str, Any]` | Configuration (Pydantic object or legacy dict) | *required* |
| `logger` | `Logger` | Logger instance | *required* |

Returns:

| Type | Description |
| --- | --- |
| `CVResults` | `CVResults` Pydantic model |

Source code in src/antibody_training_esm/core/training/metrics.py
def perform_cross_validation(
    X: np.ndarray,
    y: np.ndarray,
    config: "TrainingPipelineConfig | dict[str, Any]",
    logger: logging.Logger,
) -> CVResults:
    """
    Perform cross-validation

    Args:
        X: Embeddings array
        y: Labels array
        config: Configuration (Pydantic object or legacy dict)
        logger: Logger instance

    Returns:
        CVResults Pydantic model
    """
    from antibody_training_esm.models.config import TrainingPipelineConfig

    # Extract parameters based on config type
    if isinstance(config, TrainingPipelineConfig):
        cv_folds = config.training.n_splits
        random_state = config.training.random_state
        stratify = config.training.stratify
        model_name = config.model.name
        device = resolve_device(config.model.device)
        batch_size = config.model.batch_size
        model_type = config.model.model_type

        clf_params = config.classifier.model_dump()
    else:
        training_conf = config.get("training", {})
        classifier_conf = config.get("classifier", {})

        cv_folds = training_conf.get("n_splits", classifier_conf.get("cv_folds", 10))
        stratify = training_conf.get("stratify", True)
        random_state = training_conf.get(
            "random_state", classifier_conf.get("random_state", 42)
        )

        model_cfg = config.get("model", {})
        model_name = model_cfg.get("name", "")
        device = resolve_device(model_cfg.get("device", "cpu"))
        batch_size = training_conf.get(
            "batch_size", model_cfg.get("batch_size", DEFAULT_BATCH_SIZE)
        )
        model_type = model_cfg.get("model_type", "esm")
        clf_params = classifier_conf.copy()

    logger.info(f"Performing {cv_folds}-fold cross-validation")

    # Setup cross-validation
    if stratify:
        cv = StratifiedKFold(n_splits=cv_folds, shuffle=True, random_state=random_state)
    else:
        cv = KFold(n_splits=cv_folds, shuffle=True, random_state=random_state)

    # Create a new classifier instance for CV
    cv_params = clf_params.copy()
    cv_params["model_name"] = model_name
    cv_params["device"] = device
    cv_params["batch_size"] = batch_size
    cv_params["model_type"] = model_type

    cv_classifier = BinaryClassifier(cv_params)

    # Define metrics to compute
    scoring = {
        "accuracy": "accuracy",
        "f1": "f1",
        "precision": "precision",
        "recall": "recall",
        "roc_auc": "roc_auc",
    }

    # Perform cross-validation using cross_validate (more efficient than multiple cross_val_score calls)
    cv_scores = cross_validate(
        cv_classifier, X, y, cv=cv, scoring=scoring, return_train_score=False
    )

    # Create CVResults object using Pydantic factory
    cv_results = CVResults.from_sklearn_cv_results(cv_scores, n_splits=cv_folds)

    # Log results
    logger.info("Cross-validation Results:")
    logger.info(
        f"  Accuracy: {cv_results.cv_accuracy['mean']:.4f} (+/- {cv_results.cv_accuracy['std'] * 2:.4f})"
    )
    if cv_results.cv_f1:
        logger.info(
            f"  F1:       {cv_results.cv_f1['mean']:.4f} (+/- {cv_results.cv_f1['std'] * 2:.4f})"
        )
    if cv_results.cv_roc_auc:
        logger.info(
            f"  ROC-AUC:  {cv_results.cv_roc_auc['mean']:.4f} (+/- {cv_results.cv_roc_auc['std'] * 2:.4f})"
        )

    return cv_results
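
A sketch using the legacy-dict config path visible above; the keys match those read in the function body, and the values are illustrative. The internal `BinaryClassifier` construction may initialise the embedding backbone named in `model.name`.

```python
import logging

import numpy as np

from antibody_training_esm.core.training.metrics import perform_cross_validation

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("demo")

X = np.random.rand(100, 1280)
y = np.random.randint(0, 2, size=100)

config = {
    "training": {"n_splits": 5, "stratify": True, "random_state": 42},
    "classifier": {"C": 1.0, "max_iter": 1000},
    "model": {
        "name": "facebook/esm1v_t33_650M_UR90S_1",  # illustrative value
        "device": "cpu",
        "model_type": "esm",
    },
}

cv_results = perform_cross_validation(X, y, config, logger)
print(cv_results.cv_accuracy["mean"], cv_results.cv_accuracy["std"])
```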

### save_cv_results(cv_results, output_dir, experiment_name, logger)

Save cross-validation results to a structured YAML file.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `cv_results` | `CVResults` | `CVResults` Pydantic model | *required* |
| `output_dir` | `Path` | Directory to save the CV results file | *required* |
| `experiment_name` | `str` | Name of the experiment | *required* |
| `logger` | `Logger` | Logger instance | *required* |
Source code in src/antibody_training_esm/core/training/metrics.py
def save_cv_results(
    cv_results: CVResults,
    output_dir: Path,
    experiment_name: str,
    logger: logging.Logger,
) -> None:
    """
    Save cross-validation results to structured YAML file.

    Args:
        cv_results: CVResults Pydantic model
        output_dir: Directory to save CV results file
        experiment_name: Name of the experiment
        logger: Logger instance
    """
    # Ensure output directory exists
    output_dir.mkdir(parents=True, exist_ok=True)

    cv_file = output_dir / "cv_results.yaml"

    # Use Pydantic's model_dump for clean serialization
    results_dict = cv_results.model_dump(mode="json")

    with open(cv_file, "w") as f:
        yaml.dump(
            {
                "experiment": experiment_name,
                "timestamp": datetime.now().isoformat(),
                "cv_metrics": results_dict,
            },
            f,
            default_flow_style=False,
        )

    logger.info(f"CV results saved to {cv_file}")

### load_config(config_path)

Load configuration from YAML file.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `config_path` | `str` | Path to YAML configuration file | *required* |

Returns:

| Type | Description |
| --- | --- |
| `dict[str, Any]` | Configuration dictionary |

Raises:

| Type | Description |
| --- | --- |
| `FileNotFoundError` | If the config file doesn't exist |
| `ValueError` | If the YAML is invalid |

Source code in src/antibody_training_esm/core/training/serialization.py
def load_config(config_path: str) -> dict[str, Any]:
    """
    Load configuration from YAML file

    Args:
        config_path: Path to YAML configuration file

    Returns:
        Configuration dictionary

    Raises:
        FileNotFoundError: If config file doesn't exist
        ValueError: If YAML is invalid
    """
    try:
        with open(config_path) as f:
            config: dict[str, Any] = yaml.safe_load(f)
        return config
    except FileNotFoundError:
        raise FileNotFoundError(
            f"Config file not found: {config_path}\n"
            "Please create it or specify a valid path with --config"
        ) from None
    except yaml.YAMLError as e:
        raise ValueError(f"Invalid YAML in config file {config_path}: {e}") from e

### load_model_from_npz(npz_path, json_path)

Load model from NPZ+JSON format (production deployment).

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `npz_path` | `str` | Path to `.npz` file with arrays | *required* |
| `json_path` | `str` | Path to `.json` file with metadata | *required* |

Returns:

| Type | Description |
| --- | --- |
| `BinaryClassifier` | Reconstructed `BinaryClassifier` instance |

Notes

This function enables production deployment without pickle files. It reconstructs a fully functional BinaryClassifier from NPZ+JSON format. Uses strict Pydantic validation for metadata.

Source code in src/antibody_training_esm/core/training/serialization.py
def load_model_from_npz(npz_path: str, json_path: str) -> BinaryClassifier:
    """
    Load model from NPZ+JSON format (production deployment)

    Args:
        npz_path: Path to .npz file with arrays
        json_path: Path to .json file with metadata

    Returns:
        Reconstructed BinaryClassifier instance

    Notes:
        This function enables production deployment without pickle files.
        It reconstructs a fully functional BinaryClassifier from NPZ+JSON format.
        Uses strict Pydantic validation for metadata.
    """
    # Load arrays
    arrays = np.load(npz_path)
    coef = arrays["coef"]
    intercept = arrays["intercept"]
    classes = arrays["classes"]
    n_features_in = int(arrays["n_features_in"][0])
    n_iter = arrays["n_iter"]

    # Load metadata (Pydantic validates)
    with open(json_path) as f:
        metadata_dict = json.load(f)

    metadata = ModelArtifactMetadata.model_validate(metadata_dict)

    # Construct BinaryClassifier from metadata (Pydantic handles types)
    params = metadata.to_classifier_params()
    classifier = BinaryClassifier(params)

    # Restore fitted LogisticRegression state
    # Cast to Any because protocol doesn't enforce LogReg attributes
    inner_clf = cast(Any, classifier.classifier)
    inner_clf.classifier.coef_ = coef
    inner_clf.classifier.intercept_ = intercept
    inner_clf.classifier.classes_ = classes
    inner_clf.classifier.n_features_in_ = n_features_in
    inner_clf.classifier.n_iter_ = n_iter
    classifier.is_fitted = True

    return classifier
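
Loading a previously saved model; the paths follow the hierarchical layout described under `save_model` below:

```python
from antibody_training_esm.core.training.serialization import load_model_from_npz

clf = load_model_from_npz(
    "experiments/checkpoints/esm1v/logreg/boughter_vh_esm1v_logreg.npz",
    "experiments/checkpoints/esm1v/logreg/boughter_vh_esm1v_logreg_config.json",
)
# clf.predict / clf.predict_proba are now usable, with no pickle involved
```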

### save_model(classifier, config, logger)

Save trained model in dual format (pickle + NPZ+JSON).

Models are saved in a hierarchical directory structure:

    {model_save_dir}/{model_shortname}/{classifier_type}/{model_name}.*

Example: `experiments/checkpoints/esm1v/logreg/boughter_vh_esm1v_logreg.pkl`

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `classifier` | `BinaryClassifier` | Trained classifier | *required* |
| `config` | `TrainingPipelineConfig \| dict[str, Any]` | Configuration dictionary or Pydantic model | *required* |
| `logger` | `Logger` | Logger instance | *required* |

Returns:

| Type | Description |
| --- | --- |
| `dict[str, str]` | Dictionary with paths to saved files, e.g. `{"pickle": "experiments/checkpoints/esm1v/logreg/model.pkl", "npz": "experiments/checkpoints/esm1v/logreg/model.npz", "config": "experiments/checkpoints/esm1v/logreg/model_config.json"}`. Empty dict if saving is disabled. |

Source code in src/antibody_training_esm/core/training/serialization.py
def save_model(
    classifier: BinaryClassifier,
    config: "TrainingPipelineConfig | dict[str, Any]",
    logger: logging.Logger,
) -> dict[str, str]:
    """
    Save trained model in dual format (pickle + NPZ+JSON)

    Models are saved in hierarchical directory structure:
        {model_save_dir}/{model_shortname}/{classifier_type}/{model_name}.*

    Example:
        experiments/checkpoints/esm1v/logreg/boughter_vh_esm1v_logreg.pkl

    Args:
        classifier: Trained classifier
        config: Configuration dictionary or Pydantic model
        logger: Logger instance

    Returns:
        Dictionary with paths to saved files:
        {
            "pickle": "experiments/checkpoints/esm1v/logreg/model.pkl",
            "npz": "experiments/checkpoints/esm1v/logreg/model.npz",
            "config": "experiments/checkpoints/esm1v/logreg/model_config.json"
        }
        Empty dict if saving is disabled.
    """
    from antibody_training_esm.models.config import TrainingPipelineConfig

    # Helper to extract config values regardless of type (Dict vs Pydantic)
    if isinstance(config, TrainingPipelineConfig):
        if not config.training.save_model:
            return {}
        model_name = config.training.model_name
        base_save_dir = config.training.model_save_dir
        model_shortname = config.model.name
        classifier_config = config.classifier.model_dump()
        classifier_strategy = config.classifier.strategy
        train_metrics = getattr(config, "train_metrics", None)
    else:
        if not config["training"]["save_model"]:
            return {}
        model_name = config["training"]["model_name"]
        base_save_dir = config["training"]["model_save_dir"]
        model_shortname = config["model"]["name"]
        classifier_config = config["classifier"]
        classifier_strategy = config["classifier"].get(
            "strategy", "logistic_regression"
        )
        train_metrics = config.get("train_metrics")

    # Auto-generate model_name if empty (e.g., "boughter_vh_biophysical_logreg")
    if not model_name:
        # Extract short names for readable filename
        model_short = extract_model_shortname(model_shortname)
        classifier_short = "logreg" if "logistic" in classifier_strategy else "xgboost"
        model_name = f"boughter_vh_{model_short}_{classifier_short}"
        logger.info(f"Auto-generated model_name: {model_name}")

    # Generate hierarchical directory path
    hierarchical_dir = get_hierarchical_model_dir(
        str(base_save_dir),
        model_shortname,
        classifier_config,
    )
    hierarchical_dir.mkdir(parents=True, exist_ok=True)
    logger.info(f"Using hierarchical model directory: {hierarchical_dir}")

    base_path = hierarchical_dir / model_name

    # Format 1: Pickle checkpoint (research/debugging)
    pickle_path = f"{base_path}.pkl"
    with open(pickle_path, "wb") as f:
        pickle.dump(classifier, f)
    logger.info(f"Saved pickle checkpoint: {pickle_path}")

    # Format 2: Strategy-specific production serialization
    # Use duck typing to detect serialization method
    saved_paths = {"pickle": str(pickle_path)}

    if hasattr(classifier.classifier, "save_model"):
        # XGBoost native .xgb format (pickle-free)
        xgb_path = f"{base_path}.xgb"
        classifier.classifier.save_model(str(xgb_path))
        logger.info(f"Saved XGBoost native model: {xgb_path}")
        saved_paths["xgb"] = str(xgb_path)
    elif hasattr(classifier.classifier, "to_arrays"):
        # LogReg NPZ format (sklearn arrays)
        npz_path = f"{base_path}.npz"
        arrays = classifier.classifier.to_arrays()
        np.savez(npz_path, **cast(dict[str, Any], arrays))
        logger.info(f"Saved NPZ arrays: {npz_path}")
        saved_paths["npz"] = str(npz_path)
    else:
        # Fallback: legacy LogReg direct attribute access
        # Cast to Any because protocol doesn't enforce LogReg attributes
        inner_clf = cast(Any, classifier.classifier)
        npz_path = f"{base_path}.npz"
        np.savez(
            npz_path,
            coef=inner_clf.coef_,
            intercept=inner_clf.intercept_,
            classes=inner_clf.classes_,
            n_features_in=np.array([inner_clf.n_features_in_]),
            n_iter=inner_clf.n_iter_,
        )
        logger.info(f"Saved NPZ arrays (legacy): {npz_path}")
        saved_paths["npz"] = str(npz_path)

    # Format 3: JSON metadata (Pydantic)
    json_path = f"{base_path}_config.json"

    # Construct metadata from classifier (Pydantic handles serialization)
    metadata = ModelArtifactMetadata.from_classifier(classifier)

    # Add training metrics if available
    if train_metrics:
        metadata.training_metrics = train_metrics

    # Save as JSON (Pydantic handles type conversion)
    with open(json_path, "w") as f:
        # model_dump(mode='json') handles decimal/float serialization
        json.dump(metadata.model_dump(mode="json"), f, indent=2)

    logger.info(f"Saved JSON config: {json_path}")
    saved_paths["config"] = str(json_path)

    logger.info(f"Model saved successfully ({metadata.model_type} format)")
    return saved_paths
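
A sketch pairing `save_model` with the legacy-dict config shape it reads; `clf` is assumed to be a fitted `BinaryClassifier` (for instance, from the `evaluate_model` sketch above):

```python
import logging

from antibody_training_esm.core.training.serialization import save_model

logger = logging.getLogger("demo")

config = {
    "training": {
        "save_model": True,
        "model_name": "",  # empty -> name is auto-generated
        "model_save_dir": "experiments/checkpoints",
    },
    "model": {"name": "facebook/esm1v_t33_650M_UR90S_1"},  # illustrative value
    "classifier": {"strategy": "logistic_regression"},
}

paths = save_model(clf, config, logger)
print(paths)  # {"pickle": "...", "npz": "...", "config": "..."}
```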

### validate_config(config)

Validate config with Pydantic models.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `config` | `dict[str, Any] \| DictConfig` | Raw dict or Hydra `DictConfig` | *required* |

Returns:

| Type | Description |
| --- | --- |
| `TrainingPipelineConfig` | Validated `TrainingPipelineConfig` |

Raises:

| Type | Description |
| --- | --- |
| `ValidationError` | If the config is invalid |

Source code in src/antibody_training_esm/core/trainer.py
def validate_config(config: dict[str, Any] | DictConfig) -> TrainingPipelineConfig:
    """
    Validate config with Pydantic models.

    Args:
        config: Raw dict or Hydra DictConfig

    Returns:
        Validated TrainingPipelineConfig

    Raises:
        ValidationError: If config is invalid
    """
    if isinstance(config, DictConfig):
        return TrainingPipelineConfig.from_hydra(config)
    result: TrainingPipelineConfig = TrainingPipelineConfig.model_validate(config)
    return result
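
Either input form validates to the same Pydantic model. This sketch feeds it a raw dict loaded from YAML (the path is hypothetical):

```python
from antibody_training_esm.core.trainer import validate_config
from antibody_training_esm.core.training.serialization import load_config

raw = load_config("configs/train.yaml")  # hypothetical path
config = validate_config(raw)  # raises pydantic.ValidationError on bad input
print(config.training.n_splits, config.model.name)
```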

### setup_logging(config)

Set up logging from the Pydantic config.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `config` | `TrainingPipelineConfig` | Validated `TrainingPipelineConfig` | *required* |

Returns:

| Type | Description |
| --- | --- |
| `Logger` | Configured logger |

Source code in src/antibody_training_esm/core/trainer.py
def setup_logging(config: TrainingPipelineConfig) -> logging.Logger:
    """
    Setup logging from Pydantic config.

    Args:
        config: Validated TrainingPipelineConfig

    Returns:
        Configured logger
    """
    from hydra.core.hydra_config import HydraConfig

    log_level = getattr(logging, config.training.log_level.upper())
    log_file = config.training.log_file

    # Hydra-aware path resolution (same as before)
    try:
        hydra_cfg = HydraConfig.get()
        output_dir = Path(hydra_cfg.runtime.output_dir)
        log_path = output_dir / log_file
        log_path.parent.mkdir(parents=True, exist_ok=True)
    except (ValueError, AttributeError):
        log_path = Path(log_file)
        if not log_path.is_absolute():
            log_path = Path.cwd() / log_file
        log_path.parent.mkdir(parents=True, exist_ok=True)

    logging.basicConfig(
        level=log_level,
        format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
        handlers=[logging.FileHandler(log_path), logging.StreamHandler()],
        force=True,
    )

    return logging.getLogger(__name__)
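
Outside a Hydra run the log file resolves relative to the current working directory, so two calls are enough for scripts and tests (continuing the `validate_config` sketch above):

```python
from antibody_training_esm.core.trainer import setup_logging, validate_config

config = validate_config(raw)  # `raw` as in the previous sketch
logger = setup_logging(config)
logger.info("Logging configured at level %s", config.training.log_level)
```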

### train_pipeline(cfg)

Core training pipeline with Pydantic validation.

Source code in src/antibody_training_esm/core/trainer.py
def train_pipeline(cfg: DictConfig) -> dict[str, Any]:
    """Core training pipeline with Pydantic validation."""
    # Validate config (now returns Pydantic model)
    config = validate_config(cfg)

    # Setup logging (accepts Pydantic model now)
    logger = setup_logging(config)

    logger.info("Starting antibody classification training")
    logger.info(f"Experiment: {config.experiment.name}")

    try:
        X_train, y_train = load_data(config)

        logger.info(f"Loaded {len(X_train)} training samples")

        # Phase B (Biophysical) filtering: remove sequences with ambiguous AAs ('X')
        # and stop codons ('*'). Biopython cannot handle these, unlike ESM.
        # Filter X and y together to keep labels aligned.
        if config.model.model_type == "biophysical":
            valid_indices = [
                i for i, seq in enumerate(X_train) if "X" not in seq and "*" not in seq
            ]
            dropped_count = len(X_train) - len(valid_indices)
            if dropped_count > 0:
                logger.warning(
                    f"Biophysical model requires strict amino acids. "
                    f"Dropping {dropped_count} sequences containing 'X' or '*'."
                )
                X_train = [X_train[i] for i in valid_indices]
                y_train = [y_train[i] for i in valid_indices]
                logger.info(f"Remaining samples after filtering: {len(X_train)}")

        # Resolve device (handles auto + explicit availability validation)
        device = resolve_device(config.model.device)
        config.model.device = cast(
            Literal["cpu", "cuda", "mps", "auto"], device
        )  # Persist resolved device
        if config.hardware and isinstance(config.hardware, dict):
            # Keep hardware section in sync when present
            config.hardware["device"] = device
        logger.info(f"Using device: {device}")

        # Initialize classifier
        classifier_params = {
            "model_name": config.model.name,
            "device": device,  # Use resolved device
            "batch_size": config.model.batch_size,
            "revision": config.model.revision,
            "model_type": config.model.model_type,  # ESM or AMPLIFY
            # Classifier strategy params
            "strategy": config.classifier.strategy,
            "C": config.classifier.C,
            "penalty": config.classifier.penalty,
            "solver": config.classifier.solver,
            "class_weight": config.classifier.class_weight,
            "max_iter": config.classifier.max_iter,
            "random_state": config.classifier.random_state,
            "n_estimators": config.classifier.n_estimators,
            "max_depth": config.classifier.max_depth,
            "learning_rate": config.classifier.learning_rate,
        }

        classifier = BinaryClassifier(classifier_params)

        # Get embeddings (cache_dir from config)
        cache_dir = config.data.embeddings_cache_dir
        X_train_embedded = get_or_create_embeddings(
            X_train, classifier.embedding_extractor, cache_dir, "train", logger
        )

        # Convert labels to numpy array
        y_train_array: np.ndarray = np.array(y_train)

        # Perform CV (returns CVResults Pydantic model)
        cv_results = perform_cross_validation(
            X_train_embedded,
            y_train_array,
            config,  # Passing Pydantic model
            logger,
        )

        # Save CV results
        try:
            from hydra.core.hydra_config import HydraConfig

            hydra_cfg = HydraConfig.get()
            cv_output_dir = Path(hydra_cfg.runtime.output_dir)
            experiment_name = config.experiment.name
            logger.info(f"Saving CV results to Hydra output dir: {cv_output_dir}")
        except (ValueError, AttributeError, ImportError):
            cv_output_dir = config.training.model_save_dir
            experiment_name = config.experiment.name
            logger.info(f"Running without Hydra, saving CV results to {cv_output_dir}")

        save_cv_results(cv_results, cv_output_dir, experiment_name, logger)

        # Train final model
        classifier.fit(X_train_embedded, y_train_array)

        # Evaluate (returns EvaluationMetrics Pydantic model)
        train_results = evaluate_model(
            classifier,
            X_train_embedded,
            y_train_array,
            "Training",
            list(config.training.metrics),  # Cast to list for type safety
            logger,
        )

        # Save model
        if config.training.save_model:
            # save_model expects config dict or object.
            # We'll pass Pydantic config.
            # Attach metrics to config for metadata saving
            config.train_metrics = train_results.model_dump(
                mode="json", exclude_none=True
            )
            model_paths = save_model(classifier, config, logger)
        else:
            model_paths = {}

        return {
            "train_metrics": train_results,
            "cv_metrics": cv_results,
            "config": config.model_dump(),  # Convert back to dict for serialization
            "model_paths": model_paths,
        }

    except Exception as e:
        logger.error(f"Training failed: {e}")
        raise
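
Tests and scripts can drive the pipeline directly with a hand-built `DictConfig`, bypassing the Hydra CLI machinery (the YAML path is illustrative):

```python
from omegaconf import OmegaConf

from antibody_training_esm.core.trainer import train_pipeline

cfg = OmegaConf.load("configs/train.yaml")  # hypothetical composed config
results = train_pipeline(cfg)

print(results["train_metrics"].accuracy)
print(results["cv_metrics"].cv_accuracy["mean"])
```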

### main(cfg)

Hydra entry point for the CLI; do NOT call it directly in tests.

This is the CLI entry point decorated with `@hydra.main`. It:

- Automatically parses command-line overrides
- Creates Hydra output directories
- Saves the composed config to `.hydra/config.yaml`
- Delegates to `train_pipeline()` for the core logic

Usage:

    # Default config
    python -m antibody_training_esm.core.trainer

    # With overrides
    python -m antibody_training_esm.core.trainer model.batch_size=16

    # Multi-run sweep
    python -m antibody_training_esm.core.trainer --multirun model=esm1v,esm2

Note

Tests should call train_pipeline() directly, not this function. This function is only for CLI usage with sys.argv parsing.

Source code in src/antibody_training_esm/core/trainer.py
@hydra.main(version_base=None, config_path="../conf", config_name="config")
def main(cfg: DictConfig) -> None:
    """
    Hydra entry point for CLI - DO NOT call directly in tests

    This is the CLI entry point decorated with @hydra.main. It:
    - Automatically parses command-line overrides
    - Creates Hydra output directories
    - Saves composed config to .hydra/config.yaml
    - Delegates to train_pipeline() for core logic

    Usage:
        # Default config
        python -m antibody_training_esm.core.trainer

        # With overrides
        python -m antibody_training_esm.core.trainer model.batch_size=16

        # Multi-run sweep
        python -m antibody_training_esm.core.trainer --multirun model=esm1v,esm2

    Note:
        Tests should call train_pipeline() directly, not this function.
        This function is only for CLI usage with sys.argv parsing.
    """
    logger = logging.getLogger(__name__)
    logger.info(f"Starting training with Hydra (experiment: {cfg.experiment.name})")

    try:
        # Call core training pipeline
        results = train_pipeline(cfg)

        # Log final results (access Pydantic fields)
        train_metrics = results["train_metrics"]
        cv_metrics = results["cv_metrics"]

        logger.info("=" * LOG_SEPARATOR_WIDTH)
        logger.info("TRAINING COMPLETE")
        logger.info("=" * LOG_SEPARATOR_WIDTH)
        logger.info(f"Train Accuracy: {train_metrics.accuracy:.4f}")
        logger.info(
            f"CV Accuracy: {cv_metrics.cv_accuracy['mean']:.4f} "
            f"(+/- {cv_metrics.cv_accuracy['std'] * 2:.4f})"
        )

        if results.get("model_paths"):
            logger.info(f"Model saved to: {results['model_paths']['pickle']}")

        logger.info("=" * LOG_SEPARATOR_WIDTH)

    except Exception as e:
        logger.error(f"Training failed: {str(e)}")
        raise