
artifact

Pydantic models for model artifacts and metrics.

This module defines the schema for:

1. Saved model metadata (JSON sidecar)
2. Evaluation metrics (accuracy, F1, etc.)
3. Cross-validation results

Classes

ModelArtifactMetadata

Bases: BaseModel

Metadata for saved model artifacts.

This model structures the JSON sidecar file that accompanies NPZ/XGB model files. It enables version checking and parameter reconstruction.
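
A minimal sketch of how a sidecar might be read back; the dict contents and field values below are illustrative, not taken from a real artifact.

from antibody_training_esm.models.artifact import ModelArtifactMetadata

# Hypothetical sidecar contents, e.g. parsed from a model.meta.json file
sidecar = {
    "model_name": "facebook/esm1v_t33_650M_UR90S_1",
    "model_type": "logistic_regression",
    "sklearn_version": "1.3.0",
    "classifier": {"type": "logistic_regression", "C": 1.0},
    "esm_model": "facebook/esm1v_t33_650M_UR90S_1",
}

metadata = ModelArtifactMetadata.model_validate(sidecar)
print(metadata.embedding_model_type)  # "esm" (the default when the field is absent)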

Source code in src/antibody_training_esm/models/artifact.py
class ModelArtifactMetadata(BaseModel):
    """
    Metadata for saved model artifacts.

    This model structures the JSON sidecar file that accompanies
    NPZ/XGB model files. It enables version checking and parameter
    reconstruction.
    """

    # Model architecture
    model_name: str = Field(
        ...,
        description="HuggingFace ESM model ID",
        examples=["facebook/esm1v_t33_650M_UR90S_1"],
    )

    model_type: Literal["logistic_regression", "xgboost", "random_forest"] = Field(
        ...,
        description="Classifier type",
    )

    sklearn_version: str = Field(
        ...,
        description="scikit-learn version used for training",
        examples=["1.3.0"],
    )

    # Classifier configuration (strategy-specific)
    classifier: dict[str, Any] = Field(
        ...,
        description="Full classifier config from to_dict() method",
    )

    # ESM embedding extractor params
    esm_model: str = Field(
        ...,
        description="ESM model name (redundant with model_name, kept for compat)",
    )

    esm_revision: str = Field(
        default="main",
        description="HuggingFace model revision (commit hash)",
    )

    batch_size: int = Field(
        default=16,
        ge=1,
        description="Batch size for embedding extraction",
    )

    device: str = Field(
        default="cpu",
        description="Device used during training",
    )

    # P1.2 fix: Track embedding extractor type for proper reconstruction
    embedding_model_type: Literal["esm", "amplify", "biophysical"] = Field(
        default="esm",
        description="Type of embedding extractor (esm, amplify, or biophysical)",
    )

    @field_validator("embedding_model_type", mode="before")
    @classmethod
    def infer_embedding_model_type(cls, v: str | None, info: Any) -> str:
        """
        Infer embedding_model_type from model_name for backward compatibility.

        P1.2 fix: Old JSON files won't have this field. Infer from model_name.
        """
        if v is not None:
            return v

        # Try to infer from model_name in the data being validated
        data = info.data if hasattr(info, "data") else {}
        model_name = data.get("model_name", "") or data.get("esm_model", "")

        if "biophysical" in model_name.lower():
            return "biophysical"
        elif "amplify" in model_name.lower():
            return "amplify"
        return "esm"  # Default for old files

    # Legacy flat fields (LogReg only, for backward compatibility)
    C: float | None = Field(
        default=None,
        description="LogReg: Inverse regularization strength",
    )

    penalty: Literal["l1", "l2"] | None = Field(
        default=None,
        description="LogReg: Regularization type",
    )

    solver: str | None = Field(
        default=None,
        description="LogReg: Optimization algorithm",
    )

    # Pydantic handles dict[int, float] keys automatically (converts string keys back to int)
    class_weight: Literal["balanced"] | dict[int, float] | None = Field(
        default=None,
        description="Class weighting strategy",
    )

    max_iter: int | None = Field(
        default=None,
        description="LogReg: Maximum iterations",
    )

    random_state: int | None = Field(
        default=None,
        description="Random seed",
    )

    # Optional metrics from training
    training_metrics: dict[str, float] | None = Field(
        default=None,
        description="Metrics from final training run",
    )

    @classmethod
    def from_classifier(cls, classifier: Any) -> "ModelArtifactMetadata":
        """
        Construct metadata from BinaryClassifier instance.

        Args:
            classifier: Trained BinaryClassifier

        Returns:
            ModelArtifactMetadata
        """
        import sklearn

        strategy_config = classifier.classifier.to_dict()
        classifier_type = strategy_config.get("type", "logistic_regression")

        # P1.2 fix: Extract embedding_model_type from classifier
        embedding_model_type = getattr(classifier, "_model_type", "esm")

        metadata_dict = {
            # Model architecture
            "model_name": classifier.model_name,
            "model_type": classifier_type,
            "sklearn_version": sklearn.__version__,
            # Classifier config (strategy-specific)
            "classifier": strategy_config,
            # ESM params
            "esm_model": classifier.model_name,
            "esm_revision": classifier.revision,
            "batch_size": classifier.batch_size,
            "device": classifier.device,
            # P1.2 fix: Persist embedding extractor type
            "embedding_model_type": embedding_model_type,
        }

        # Add legacy flat fields for LogReg (backward compat)
        if classifier_type == "logistic_regression":
            metadata_dict.update(
                {
                    "C": classifier.C,
                    "penalty": classifier.penalty,
                    "solver": classifier.solver,
                    "class_weight": classifier.class_weight,
                    "max_iter": classifier.max_iter,
                    "random_state": classifier.random_state,
                }
            )

        return cls.model_validate(metadata_dict)

    def to_classifier_params(self) -> dict[str, Any]:
        """
        Extract parameters for BinaryClassifier reconstruction.

        Returns:
            Dict of parameters for BinaryClassifier(...) init
        """
        params = {
            # ESM params
            "model_name": self.esm_model,
            "device": self.device,
            "batch_size": self.batch_size,
            "revision": self.esm_revision,
            # P1.2 fix: Include model_type for proper embedder reconstruction
            "model_type": self.embedding_model_type,
            # Classifier params
            **self.classifier,
        }

        # Overwrite with typed fields for LogReg to ensure correct types (e.g. int keys in dict)
        if self.model_type == "logistic_regression":
            params.update(
                {
                    "C": self.C,
                    "penalty": self.penalty,
                    "solver": self.solver,
                    "class_weight": self.class_weight,
                    "max_iter": self.max_iter,
                    "random_state": self.random_state,
                }
            )

        return params
Functions
infer_embedding_model_type(v, info) classmethod

Infer embedding_model_type from model_name for backward compatibility.

P1.2 fix: Old JSON files won't have this field. Infer from model_name.

Source code in src/antibody_training_esm/models/artifact.py
@field_validator("embedding_model_type", mode="before")
@classmethod
def infer_embedding_model_type(cls, v: str | None, info: Any) -> str:
    """
    Infer embedding_model_type from model_name for backward compatibility.

    P1.2 fix: Old JSON files won't have this field. Infer from model_name.
    """
    if v is not None:
        return v

    # Try to infer from model_name in the data being validated
    data = info.data if hasattr(info, "data") else {}
    model_name = data.get("model_name", "") or data.get("esm_model", "")

    if "biophysical" in model_name.lower():
        return "biophysical"
    elif "amplify" in model_name.lower():
        return "amplify"
    return "esm"  # Default for old files
from_classifier(classifier) classmethod

Construct metadata from BinaryClassifier instance.

Parameters:

- classifier (Any): Trained BinaryClassifier (required)

Returns:

- ModelArtifactMetadata

Source code in src/antibody_training_esm/models/artifact.py
@classmethod
def from_classifier(cls, classifier: Any) -> "ModelArtifactMetadata":
    """
    Construct metadata from BinaryClassifier instance.

    Args:
        classifier: Trained BinaryClassifier

    Returns:
        ModelArtifactMetadata
    """
    import sklearn

    strategy_config = classifier.classifier.to_dict()
    classifier_type = strategy_config.get("type", "logistic_regression")

    # P1.2 fix: Extract embedding_model_type from classifier
    embedding_model_type = getattr(classifier, "_model_type", "esm")

    metadata_dict = {
        # Model architecture
        "model_name": classifier.model_name,
        "model_type": classifier_type,
        "sklearn_version": sklearn.__version__,
        # Classifier config (strategy-specific)
        "classifier": strategy_config,
        # ESM params
        "esm_model": classifier.model_name,
        "esm_revision": classifier.revision,
        "batch_size": classifier.batch_size,
        "device": classifier.device,
        # P1.2 fix: Persist embedding extractor type
        "embedding_model_type": embedding_model_type,
    }

    # Add legacy flat fields for LogReg (backward compat)
    if classifier_type == "logistic_regression":
        metadata_dict.update(
            {
                "C": classifier.C,
                "penalty": classifier.penalty,
                "solver": classifier.solver,
                "class_weight": classifier.class_weight,
                "max_iter": classifier.max_iter,
                "random_state": classifier.random_state,
            }
        )

    return cls.model_validate(metadata_dict)
to_classifier_params()

Extract parameters for BinaryClassifier reconstruction.

Returns:

- dict[str, Any]: Dict of parameters for BinaryClassifier(...) init

Source code in src/antibody_training_esm/models/artifact.py
def to_classifier_params(self) -> dict[str, Any]:
    """
    Extract parameters for BinaryClassifier reconstruction.

    Returns:
        Dict of parameters for BinaryClassifier(...) init
    """
    params = {
        # ESM params
        "model_name": self.esm_model,
        "device": self.device,
        "batch_size": self.batch_size,
        "revision": self.esm_revision,
        # P1.2 fix: Include model_type for proper embedder reconstruction
        "model_type": self.embedding_model_type,
        # Classifier params
        **self.classifier,
    }

    # Overwrite with typed fields for LogReg to ensure correct types (e.g. int keys in dict)
    if self.model_type == "logistic_regression":
        params.update(
            {
                "C": self.C,
                "penalty": self.penalty,
                "solver": self.solver,
                "class_weight": self.class_weight,
                "max_iter": self.max_iter,
                "random_state": self.random_state,
            }
        )

    return params
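
A sketch of the intended save/load round trip, assuming `clf` is a trained BinaryClassifier exposing the attributes that from_classifier() reads; the sidecar path is illustrative.

from antibody_training_esm.models.artifact import ModelArtifactMetadata

metadata = ModelArtifactMetadata.from_classifier(clf)

# Persist the JSON sidecar next to the model file
with open("model.meta.json", "w") as fh:
    fh.write(metadata.model_dump_json(indent=2))

# Later: reload the sidecar and rebuild constructor arguments
with open("model.meta.json") as fh:
    loaded = ModelArtifactMetadata.model_validate_json(fh.read())

params = loaded.to_classifier_params()
# Per the docstring, params is meant to feed BinaryClassifier(...) before
# the fitted weights are loaded from the NPZ/XGB file.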

EvaluationMetrics

Bases: BaseModel

Evaluation metrics for a single dataset.

Used for training set, test set, and cross-validation fold results.

Source code in src/antibody_training_esm/models/artifact.py
class EvaluationMetrics(BaseModel):
    """
    Evaluation metrics for a single dataset.

    Used for training set, test set, and cross-validation fold results.
    """

    accuracy: float = Field(
        ...,
        ge=0.0,
        le=1.0,
        description="Classification accuracy (0-1)",
    )

    precision: float | None = Field(
        default=None,
        ge=0.0,
        le=1.0,
        description="Precision (positive predictive value)",
    )

    recall: float | None = Field(
        default=None,
        ge=0.0,
        le=1.0,
        description="Recall (sensitivity, true positive rate)",
    )

    f1: float | None = Field(
        default=None,
        ge=0.0,
        le=1.0,
        description="F1 score (harmonic mean of precision and recall)",
    )

    roc_auc: float | None = Field(
        default=None,
        ge=0.0,
        le=1.0,
        description="Area under ROC curve",
    )

    # Optional confusion matrix
    confusion_matrix: list[list[int]] | None = Field(
        default=None,
        description="Confusion matrix [[TN, FP], [FN, TP]]",
    )

    # Dataset metadata
    dataset_name: str | None = Field(
        default=None,
        description="Name of evaluated dataset (e.g., 'Jain', 'Training')",
    )

    n_samples: int | None = Field(
        default=None,
        ge=0,
        description="Number of samples in dataset",
    )

    model_config = {
        "json_schema_extra": {
            "examples": [
                {
                    "accuracy": 0.85,
                    "precision": 0.83,
                    "recall": 0.88,
                    "f1": 0.85,
                    "roc_auc": 0.90,
                    "confusion_matrix": [[82, 18], [12, 88]],
                    "dataset_name": "Example",
                    "n_samples": 200,
                }
            ]
        }
    }

    @classmethod
    def from_sklearn_metrics(
        cls,
        y_true: np.ndarray,
        y_pred: np.ndarray,
        y_proba: np.ndarray | None = None,
        dataset_name: str | None = None,
    ) -> "EvaluationMetrics":
        """
        Construct metrics from sklearn predictions.

        Args:
            y_true: Ground truth labels
            y_pred: Predicted labels
            y_proba: Predicted probabilities (for ROC-AUC)
            dataset_name: Name of dataset

        Returns:
            EvaluationMetrics
        """
        from sklearn.metrics import (
            accuracy_score,
            confusion_matrix,
            f1_score,
            precision_score,
            recall_score,
            roc_auc_score,
        )

        metrics_dict = {
            "accuracy": float(accuracy_score(y_true, y_pred)),
            "precision": float(precision_score(y_true, y_pred, zero_division=0)),
            "recall": float(recall_score(y_true, y_pred, zero_division=0)),
            "f1": float(f1_score(y_true, y_pred, zero_division=0)),
            "dataset_name": dataset_name,
            "n_samples": len(y_true),
            "confusion_matrix": confusion_matrix(y_true, y_pred).tolist(),
        }

        # ROC-AUC requires probabilities
        if y_proba is not None:
            try:
                # Check if y_proba has 2 columns (binary classification)
                if y_proba.ndim == 2 and y_proba.shape[1] >= 2:
                    score = roc_auc_score(y_true, y_proba[:, 1])
                else:
                    # Fallback for 1D array if passed incorrectly
                    score = roc_auc_score(y_true, y_proba)
                metrics_dict["roc_auc"] = float(score)
            except ValueError:
                # ROC AUC might fail if only one class is present in y_true
                metrics_dict["roc_auc"] = None

        return cls.model_validate(metrics_dict)
Functions
from_sklearn_metrics(y_true, y_pred, y_proba=None, dataset_name=None) classmethod

Construct metrics from sklearn predictions.

Parameters:

- y_true (ndarray): Ground truth labels (required)
- y_pred (ndarray): Predicted labels (required)
- y_proba (ndarray | None): Predicted probabilities (for ROC-AUC). Default: None
- dataset_name (str | None): Name of dataset. Default: None

Returns:

- EvaluationMetrics

Source code in src/antibody_training_esm/models/artifact.py
@classmethod
def from_sklearn_metrics(
    cls,
    y_true: np.ndarray,
    y_pred: np.ndarray,
    y_proba: np.ndarray | None = None,
    dataset_name: str | None = None,
) -> "EvaluationMetrics":
    """
    Construct metrics from sklearn predictions.

    Args:
        y_true: Ground truth labels
        y_pred: Predicted labels
        y_proba: Predicted probabilities (for ROC-AUC)
        dataset_name: Name of dataset

    Returns:
        EvaluationMetrics
    """
    from sklearn.metrics import (
        accuracy_score,
        confusion_matrix,
        f1_score,
        precision_score,
        recall_score,
        roc_auc_score,
    )

    metrics_dict = {
        "accuracy": float(accuracy_score(y_true, y_pred)),
        "precision": float(precision_score(y_true, y_pred, zero_division=0)),
        "recall": float(recall_score(y_true, y_pred, zero_division=0)),
        "f1": float(f1_score(y_true, y_pred, zero_division=0)),
        "dataset_name": dataset_name,
        "n_samples": len(y_true),
        "confusion_matrix": confusion_matrix(y_true, y_pred).tolist(),
    }

    # ROC-AUC requires probabilities
    if y_proba is not None:
        try:
            # Check if y_proba has 2 columns (binary classification)
            if y_proba.ndim == 2 and y_proba.shape[1] >= 2:
                score = roc_auc_score(y_true, y_proba[:, 1])
            else:
                # Fallback for 1D array if passed incorrectly
                score = roc_auc_score(y_true, y_proba)
            metrics_dict["roc_auc"] = float(score)
        except ValueError:
            # ROC AUC might fail if only one class is present in y_true
            metrics_dict["roc_auc"] = None

    return cls.model_validate(metrics_dict)
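
A short usage sketch with synthetic labels and two-column probabilities (as returned by predict_proba); the arrays are illustrative.

import numpy as np

from antibody_training_esm.models.artifact import EvaluationMetrics

y_true = np.array([0, 0, 1, 1, 1])
y_pred = np.array([0, 1, 1, 1, 0])
y_proba = np.array([[0.9, 0.1], [0.4, 0.6], [0.2, 0.8], [0.3, 0.7], [0.6, 0.4]])

metrics = EvaluationMetrics.from_sklearn_metrics(
    y_true, y_pred, y_proba=y_proba, dataset_name="Example"
)
print(metrics.accuracy, metrics.f1, metrics.roc_auc)
print(metrics.confusion_matrix)  # [[TN, FP], [FN, TP]]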

CVResults

Bases: BaseModel

Cross-validation results with mean and std for each metric.

Aggregates metrics across all CV folds.

Source code in src/antibody_training_esm/models/artifact.py
class CVResults(BaseModel):
    """
    Cross-validation results with mean and std for each metric.

    Aggregates metrics across all CV folds.
    """

    cv_accuracy: dict[Literal["mean", "std"], float] = Field(
        ...,
        description="Mean and std of accuracy across folds",
    )

    cv_precision: dict[Literal["mean", "std"], float] | None = Field(
        default=None,
        description="Mean and std of precision",
    )

    cv_recall: dict[Literal["mean", "std"], float] | None = Field(
        default=None,
        description="Mean and std of recall",
    )

    cv_f1: dict[Literal["mean", "std"], float] | None = Field(
        default=None,
        description="Mean and std of F1 score",
    )

    cv_roc_auc: dict[Literal["mean", "std"], float] | None = Field(
        default=None,
        description="Mean and std of ROC-AUC",
    )

    n_splits: int = Field(
        ...,
        ge=2,
        description="Number of cross-validation folds",
    )

    # Optional: per-fold results
    fold_results: list[EvaluationMetrics] | None = Field(
        default=None,
        description="Metrics for each individual fold",
    )

    model_config = {
        "json_schema_extra": {
            "examples": [
                {
                    "cv_accuracy": {"mean": 0.82, "std": 0.05},
                    "cv_precision": {"mean": 0.78, "std": 0.06},
                    "cv_recall": {"mean": 0.85, "std": 0.04},
                    "cv_f1": {"mean": 0.81, "std": 0.05},
                    "cv_roc_auc": {"mean": 0.87, "std": 0.03},
                    "n_splits": 10,
                }
            ]
        }
    }

    @classmethod
    def from_sklearn_cv_results(
        cls,
        cv_scores: dict[str, list[float] | np.ndarray],
        n_splits: int,
    ) -> "CVResults":
        """
        Construct CVResults from sklearn cross_validate output.

        Args:
            cv_scores: Dict like {"test_accuracy": [...], "test_f1": [...]}
            n_splits: Number of folds

        Returns:
            CVResults
        """
        results_dict: dict[str, Any] = {"n_splits": n_splits}

        # Map sklearn metric names to our field names
        metric_map = {
            "test_accuracy": "cv_accuracy",
            "test_precision": "cv_precision",
            "test_recall": "cv_recall",
            "test_f1": "cv_f1",
            "test_roc_auc": "cv_roc_auc",
        }

        for sklearn_name, pydantic_name in metric_map.items():
            if sklearn_name in cv_scores:
                scores = cv_scores[sklearn_name]
                # Handle potential NaN in scores
                valid_scores = [s for s in scores if not np.isnan(s)]

                if valid_scores:
                    results_dict[pydantic_name] = {
                        "mean": float(np.mean(valid_scores)),
                        "std": float(np.std(valid_scores)),
                    }
                else:
                    results_dict[pydantic_name] = {
                        "mean": 0.0,
                        "std": 0.0,
                    }

        return cls.model_validate(results_dict)
Functions
from_sklearn_cv_results(cv_scores, n_splits) classmethod

Construct CVResults from sklearn cross_validate output.

Parameters:

- cv_scores (dict[str, list[float] | ndarray]): Dict like {"test_accuracy": [...], "test_f1": [...]} (required)
- n_splits (int): Number of folds (required)

Returns:

- CVResults

Source code in src/antibody_training_esm/models/artifact.py
@classmethod
def from_sklearn_cv_results(
    cls,
    cv_scores: dict[str, list[float] | np.ndarray],
    n_splits: int,
) -> "CVResults":
    """
    Construct CVResults from sklearn cross_validate output.

    Args:
        cv_scores: Dict like {"test_accuracy": [...], "test_f1": [...]}
        n_splits: Number of folds

    Returns:
        CVResults
    """
    results_dict: dict[str, Any] = {"n_splits": n_splits}

    # Map sklearn metric names to our field names
    metric_map = {
        "test_accuracy": "cv_accuracy",
        "test_precision": "cv_precision",
        "test_recall": "cv_recall",
        "test_f1": "cv_f1",
        "test_roc_auc": "cv_roc_auc",
    }

    for sklearn_name, pydantic_name in metric_map.items():
        if sklearn_name in cv_scores:
            scores = cv_scores[sklearn_name]
            # Handle potential NaN in scores
            valid_scores = [s for s in scores if not np.isnan(s)]

            if valid_scores:
                results_dict[pydantic_name] = {
                    "mean": float(np.mean(valid_scores)),
                    "std": float(np.std(valid_scores)),
                }
            else:
                results_dict[pydantic_name] = {
                    "mean": 0.0,
                    "std": 0.0,
                }

    return cls.model_validate(results_dict)
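
A sketch of feeding sklearn's cross_validate output into CVResults; the estimator and random data stand in for real embeddings and labels.

import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import cross_validate

from antibody_training_esm.models.artifact import CVResults

rng = np.random.default_rng(0)
X = rng.normal(size=(100, 16))    # placeholder features
y = rng.integers(0, 2, size=100)  # placeholder binary labels

cv_scores = cross_validate(
    LogisticRegression(max_iter=1000),
    X,
    y,
    cv=5,
    scoring=["accuracy", "precision", "recall", "f1", "roc_auc"],
)

results = CVResults.from_sklearn_cv_results(cv_scores, n_splits=5)
print(results.cv_accuracy)  # {"mean": ..., "std": ...}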