Skip to content

Classifier

classifier

Binary Classifier Module

Professional binary classifier for antibody sequences using ESM-1V, AMPLIFY, or biophysical embeddings. Includes sklearn compatibility, assay-specific thresholds, and model serialization.

Supported model types
  • "esm" (default): ESM-1v/ESM-2 models (facebook/esm1v_*, facebook/esm2_*)
  • "amplify": AMPLIFY 350M model (chandar-lab/AMPLIFY_350M)
  • "biophysical": Biophysical feature extractor (3-descriptor Track B)

Classes

EmbeddingExtractorProtocol

Bases: Protocol

Protocol for embedding extractors (ESM, AMPLIFY, etc.)

Source code in src/antibody_training_esm/core/classifier.py
class EmbeddingExtractorProtocol(Protocol):
    """Protocol for embedding extractors (ESM, AMPLIFY, etc.)

    Structural interface used to annotate ``BinaryClassifier.embedding_extractor``.
    Any object exposing the attributes and methods below satisfies it; no
    inheritance is required.
    """

    # Model identifier the extractor was created with
    model_name: str
    # Device the extractor runs on; BinaryClassifier copies this value after construction
    device: str
    # Batch size the extractor was created with
    batch_size: int
    # Model revision the extractor was created with
    revision: str
    # NOTE(review): presumably the maximum sequence length accepted - confirm against implementations
    max_length: int

    def embed_sequence(self, sequence: str) -> np.ndarray:
        """Extract embedding for a single sequence

        Args:
            sequence: Sequence to embed

        Returns:
            Embedding as a NumPy array (shape is not specified by the protocol)
        """
        ...

    def extract_batch_embeddings(self, sequences: list[str]) -> np.ndarray:
        """Extract embeddings for multiple sequences

        Args:
            sequences: Sequences to embed

        Returns:
            Embeddings as a NumPy array (shape is not specified by the protocol)
        """
        ...
Functions
embed_sequence(sequence)

Extract embedding for a single sequence

Source code in src/antibody_training_esm/core/classifier.py
def embed_sequence(self, sequence: str) -> np.ndarray:
    """Extract embedding for a single sequence

    Args:
        sequence: Sequence to embed

    Returns:
        Embedding as a NumPy array (shape is not specified by the protocol)
    """
    ...
extract_batch_embeddings(sequences)

Extract embeddings for multiple sequences

Source code in src/antibody_training_esm/core/classifier.py
def extract_batch_embeddings(self, sequences: list[str]) -> np.ndarray:
    """Extract embeddings for multiple sequences

    Args:
        sequences: Sequences to embed

    Returns:
        Embeddings as a NumPy array (shape is not specified by the protocol)
    """
    ...

BinaryClassifier

Binary classifier for protein sequences using ESM, AMPLIFY, or biophysical embeddings

Source code in src/antibody_training_esm/core/classifier.py
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
class BinaryClassifier:
    """Binary classifier for protein sequences using ESM-1V or AMPLIFY embeddings"""

    # sklearn 1.7+ requires explicit estimator type for cross_val_score
    # This tells sklearn's validation logic that we're a classifier, not a regressor
    _estimator_type = "classifier"

    # Assay-specific thresholds (Novo Nordisk methodology)
    ASSAY_THRESHOLDS = {
        "ELISA": 0.5,  # Training data type (Boughter, Jain)
        "PSR": 0.5495,  # PSR assay type (Shehata, Harvey) - EXACT Novo parity
    }

    # Type annotation for embedding extractor (ESM or AMPLIFY)
    embedding_extractor: EmbeddingExtractorProtocol

    def __init__(self, params: dict[str, Any] | None = None, **kwargs: Any):
        """
        Initialize the binary classifier

        Args:
            params: Dictionary containing classifier parameters (legacy API)
            **kwargs: Individual parameters (for sklearn compatibility)

        Notes:
            Supports both dict-based (legacy) and kwargs-based (sklearn) initialization
        """
        # Support both dict-based (legacy) and kwargs-based (sklearn) initialization
        if params is None:
            params = kwargs

        # Validate required parameters (universal across all strategies)
        # Note: max_iter is LogReg-specific, removed from required params
        REQUIRED_PARAMS = ["random_state", "model_name", "device"]
        missing = [p for p in REQUIRED_PARAMS if p not in params]
        if missing:
            raise ValueError(
                f"Missing required parameters: {missing}. "
                f"BinaryClassifier requires: {REQUIRED_PARAMS}"
            )

        random_state = params["random_state"]
        batch_size = params.get(
            "batch_size", DEFAULT_BATCH_SIZE
        )  # Default if not provided
        revision = params.get("revision", "main")  # HF model revision (default: "main")

        # Select embedding extractor based on model_type (default: "esm" for backward compat)
        model_type = params.get("model_type", "esm")

        if model_type not in SUPPORTED_MODEL_TYPES:
            raise ValueError(
                f"Unknown model_type: '{model_type}'. "
                f"Supported types: {sorted(SUPPORTED_MODEL_TYPES)}"
            )

        if model_type == "amplify":
            # AMPLIFY requires special handling (batch_size=1, trust_remote_code, etc.)
            from antibody_training_esm.core.embeddings_amplify import (
                AMPLIFYEmbeddingExtractor,
            )

            self.embedding_extractor = AMPLIFYEmbeddingExtractor(
                params["model_name"],
                params["device"],
                batch_size,  # Will be forced to 1 by AMPLIFYEmbeddingExtractor
                revision=revision,
            )
        elif model_type == "biophysical":
            # Biophysical descriptors (Phase B)
            from antibody_training_esm.core.embeddings_biophysical import (
                BiophysicalEmbeddingExtractor,
            )

            self.embedding_extractor = BiophysicalEmbeddingExtractor(
                params["model_name"],
                params["device"],
                batch_size,
                revision=revision,
            )
        else:
            # ESM-1v, ESM-2 (default)
            self.embedding_extractor = ESMEmbeddingExtractor(
                params["model_name"], params["device"], batch_size, revision=revision
            )

        # Store model_type for get_params/set_params
        self._model_type = model_type

        # Use factory to create classifier strategy (supports LogReg, XGBoost, etc.)
        self.classifier: ClassifierStrategy = create_classifier(params)

        logger.info(
            "Classifier initialized: type=%s, params=%s",
            params.get("type", "logistic_regression"),
            self.classifier.get_params(),
        )

        # Store hyperparameters for recreation and sklearn compatibility
        self.random_state = random_state
        self.is_fitted = False
        self.device = self.embedding_extractor.device
        self.model_name = params["model_name"]
        self.batch_size = batch_size
        self.revision = revision  # Store HF model revision for reproducibility

        # Store all params for sklearn compatibility
        self._params = params

    def get_params(self, deep: bool = True) -> dict[str, Any]:
        """
        Get parameters for sklearn compatibility (required for cross_val_score)

        Args:
            deep: If True, return parameters for sub-estimators

        Returns:
            Dictionary of parameters (embedding params + classifier params)
        """
        # Merge embedding extractor params + classifier strategy params
        params = {
            "random_state": self.random_state,
            "model_name": self.model_name,
            "device": self.device,
            "batch_size": self.batch_size,
            "revision": self.revision,
            "model_type": self._model_type,
        }
        # Add classifier-specific params from strategy
        params.update(self.classifier.get_params(deep=deep))
        return params

    def set_params(self, **params: Any) -> "BinaryClassifier":
        """
        Set parameters for sklearn compatibility (required for cross_val_score)

        Args:
            **params: Parameters to set

        Returns:
            self

        Notes:
            This method updates parameters without destroying fitted state.
            If model_name or device changes, the embedding extractor is recreated.
            If classifier type changes, the classifier strategy is recreated.
        """
        # Update internal params dict
        self._params.update(params)

        # Track if we need to recreate components
        needs_extractor_reload = False
        needs_classifier_reload = False

        # Check which components need reloading
        embedding_params = {
            "model_name",
            "device",
            "batch_size",
            "revision",
            "model_type",
        }
        if any(key in params for key in embedding_params):
            needs_extractor_reload = True
            # Update instance attributes
            self.model_name = self._params.get("model_name", self.model_name)
            self.device = self._params.get("device", self.device)
            self.batch_size = self._params.get("batch_size", self.batch_size)
            self.revision = self._params.get("revision", self.revision)
            self._model_type = self._params.get("model_type", self._model_type)

        if "type" in params:
            needs_classifier_reload = True

        # Update random_state (used by both components)
        if "random_state" in params:
            self.random_state = params["random_state"]

        # Recreate embedding extractor if needed
        if needs_extractor_reload:
            logger.info(
                f"Recreating embedding extractor: model_name={self.model_name}, "
                f"device={self.device}, batch_size={self.batch_size}, model_type={self._model_type}"
            )
            if self._model_type == "amplify":
                from antibody_training_esm.core.embeddings_amplify import (
                    AMPLIFYEmbeddingExtractor,
                )

                self.embedding_extractor = AMPLIFYEmbeddingExtractor(
                    self.model_name,
                    self.device,
                    self.batch_size,
                    revision=self.revision,
                )
            elif self._model_type == "biophysical":
                from antibody_training_esm.core.embeddings_biophysical import (
                    BiophysicalEmbeddingExtractor,
                )

                self.embedding_extractor = BiophysicalEmbeddingExtractor(
                    self.model_name,
                    self.device,
                    self.batch_size,
                    revision=self.revision,
                )
            else:
                self.embedding_extractor = ESMEmbeddingExtractor(
                    self.model_name,
                    self.device,
                    self.batch_size,
                    revision=self.revision,
                )

        # Recreate classifier strategy if type changed
        if needs_classifier_reload:
            logger.info(f"Recreating classifier: type={params.get('type')}")
            self.classifier = create_classifier(self._params)
            self.is_fitted = False  # New classifier is unfitted
        else:
            # Update existing classifier params (e.g., C, penalty, solver)
            classifier_params = {
                k: v
                for k, v in params.items()
                if k not in embedding_params and k not in {"random_state", "type"}
            }
            if classifier_params:
                # For LogReg and other sklearn estimators, update attributes directly
                for key, value in classifier_params.items():
                    if hasattr(self.classifier, key):
                        setattr(self.classifier, key, value)
                        # Also update underlying sklearn classifier
                        if hasattr(self.classifier, "classifier") and hasattr(
                            self.classifier.classifier, key
                        ):
                            setattr(self.classifier.classifier, key, value)

        return self

    def fit(self, X: np.ndarray, y: np.ndarray) -> None:
        """
        Fit the classifier to the data

        Args:
            X: Array of ESM-1V embeddings
            y: Array of labels
        """
        # Fit the classifier directly on embeddings (no scaling per Novo methodology)
        self.classifier.fit(X, y)
        self.is_fitted = True

        # sklearn 1.7+ requires classes_ attribute for cross_val_score compatibility
        self.classes_ = self.classifier.classes_

        logger.info(f"Classifier fitted on {len(X)} samples")

    def predict(
        self, X: np.ndarray, threshold: float = 0.5, assay_type: str | None = None
    ) -> np.ndarray:
        """
        Predict labels for the data with optional assay-specific thresholds

        Args:
            X: Array of ESM-1V embeddings
            threshold: Decision threshold for classification (default: 0.5)
                      Ignored if assay_type is specified
            assay_type: Type of assay for dataset-specific thresholds. Options:
                       - 'ELISA': Use threshold=0.5 (for Jain, Boughter datasets)
                       - 'PSR': Use threshold=0.5495 (for Shehata, Harvey datasets)
                       - None: Use the threshold parameter

        Returns:
            Predicted labels

        Raises:
            ValueError: If classifier is not fitted or assay_type is unknown

        Notes:
            The model was trained on ELISA data (Boughter dataset). Different assay types
            measure different "spectrums" of non-specificity (Sakhnini et al. 2025, Section 2.7).
            Use assay_type='PSR' for PSR-based datasets to get calibrated predictions.
        """
        if not self.is_fitted:
            raise ValueError("Classifier must be fitted before making predictions")

        # Determine which threshold to use
        if assay_type is not None:
            if assay_type not in self.ASSAY_THRESHOLDS:
                raise ValueError(
                    f"Unknown assay_type '{assay_type}'. Must be one of: {list(self.ASSAY_THRESHOLDS.keys())}"
                )
            threshold = self.ASSAY_THRESHOLDS[assay_type]

        # Get probabilities and apply threshold
        probabilities = self.classifier.predict_proba(X)
        predictions: np.ndarray = (probabilities[:, 1] > threshold).astype(int)

        return predictions

    def predict_proba(self, X: np.ndarray) -> np.ndarray:
        """
        Predict class probabilities for the data

        Args:
            X: Array of ESM-1V embeddings

        Returns:
            Predicted probabilities

        Raises:
            ValueError: If classifier is not fitted
        """
        if not self.is_fitted:
            raise ValueError("Classifier must be fitted before making predictions")

        result: np.ndarray = self.classifier.predict_proba(X)
        return result

    def score(self, X: np.ndarray, y: np.ndarray) -> float:
        """
        Return the mean accuracy on the given test data and labels

        Args:
            X: Array of ESM-1V embeddings
            y: Array of true labels

        Returns:
            Mean accuracy

        Raises:
            ValueError: If classifier is not fitted
        """
        if not self.is_fitted:
            raise ValueError("Classifier must be fitted before scoring")

        score: float = self.classifier.score(X, y)
        return score

    # ========================================================================
    # Backward Compatibility Properties (delegate to strategy)
    # ========================================================================

    @property
    def C(self) -> float:
        """Regularization parameter (LogReg only, for backward compatibility)"""
        return getattr(self.classifier, "C", 1.0)

    @property
    def penalty(self) -> str:
        """Regularization type (LogReg only, for backward compatibility)"""
        return getattr(self.classifier, "penalty", "l2")

    @property
    def solver(self) -> str:
        """Optimization algorithm (LogReg only, for backward compatibility)"""
        return getattr(self.classifier, "solver", "lbfgs")

    @property
    def class_weight(self) -> Any:
        """Class weights (LogReg only, for backward compatibility)"""
        return getattr(self.classifier, "class_weight", None)

    @property
    def max_iter(self) -> int:
        """Maximum iterations (LogReg only, for backward compatibility)"""
        return getattr(self.classifier, "max_iter", 1000)

    def __getstate__(self) -> dict[str, Any]:
        """Custom pickle method - don't save the ESM model"""
        state = self.__dict__.copy()
        # Remove the embedding_extractor (it will be recreated on load)
        state.pop("embedding_extractor", None)
        return state

    def __setstate__(self, state: dict[str, Any]) -> None:
        """Custom unpickle method - recreate ESM/AMPLIFY model with correct config"""
        self.__dict__.update(state)

        # Check for missing attributes from old model versions
        warnings_issued = []
        if not hasattr(self, "batch_size"):
            warnings_issued.append(f"batch_size (using default: {DEFAULT_BATCH_SIZE})")
        if not hasattr(self, "revision"):
            warnings_issued.append("revision (using default: 'main')")
        if not hasattr(self, "_model_type"):
            warnings_issued.append("_model_type (using default: 'esm')")

        if warnings_issued:
            import warnings

            warnings.warn(
                f"Loading old model missing attributes: {', '.join(warnings_issued)}. "
                "Predictions may differ from original model. Consider retraining with current version.",
                UserWarning,
                stacklevel=2,
            )

        # Recreate embedding extractor with fixed configuration
        batch_size = getattr(
            self, "batch_size", DEFAULT_BATCH_SIZE
        )  # Default if not stored (backwards compatibility)
        revision = getattr(
            self, "revision", "main"
        )  # Default if not stored (backwards compatibility)
        model_type = getattr(
            self, "_model_type", "esm"
        )  # Default if not stored (backwards compatibility)

        if model_type == "amplify":
            from antibody_training_esm.core.embeddings_amplify import (
                AMPLIFYEmbeddingExtractor,
            )

            self.embedding_extractor = AMPLIFYEmbeddingExtractor(
                self.model_name, self.device, batch_size, revision=revision
            )
        elif model_type == "biophysical":
            from antibody_training_esm.core.embeddings_biophysical import (
                BiophysicalEmbeddingExtractor,
            )

            self.embedding_extractor = BiophysicalEmbeddingExtractor(
                self.model_name, self.device, batch_size, revision=revision
            )
        else:
            self.embedding_extractor = ESMEmbeddingExtractor(
                self.model_name, self.device, batch_size, revision=revision
            )
Attributes
C property

Regularization parameter (LogReg only, for backward compatibility)

penalty property

Regularization type (LogReg only, for backward compatibility)

solver property

Optimization algorithm (LogReg only, for backward compatibility)

class_weight property

Class weights (LogReg only, for backward compatibility)

max_iter property

Maximum iterations (LogReg only, for backward compatibility)

Functions
get_params(deep=True)

Get parameters for sklearn compatibility (required for cross_val_score)

Parameters:

Name Type Description Default
deep bool

If True, return parameters for sub-estimators

True

Returns:

Type Description
dict[str, Any]

Dictionary of parameters (embedding params + classifier params)

Source code in src/antibody_training_esm/core/classifier.py
def get_params(self, deep: bool = True) -> dict[str, Any]:
    """
    Report the estimator's parameters (sklearn compatibility, needed by cross_val_score)

    Args:
        deep: Passed through to the classifier strategy's ``get_params``

    Returns:
        Embedding-related parameters merged with the classifier strategy's
        parameters; on a key clash the strategy's value wins
    """
    own: dict[str, Any] = dict(
        random_state=self.random_state,
        model_name=self.model_name,
        device=self.device,
        batch_size=self.batch_size,
        revision=self.revision,
        model_type=self._model_type,
    )
    # Strategy params are layered on top of the embedding params
    return {**own, **self.classifier.get_params(deep=deep)}
set_params(**params)

Set parameters for sklearn compatibility (required for cross_val_score)

Parameters:

Name Type Description Default
**params Any

Parameters to set

{}

Returns:

Type Description
BinaryClassifier

self

Notes

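This method updates parameters in place. If any embedding parameter (model_name, device, batch_size, revision, model_type) is passed, the embedding extractor is recreated. If the classifier type is passed, the classifier strategy is recreated and the fitted state is reset; otherwise the fitted state is left untouched.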

Source code in src/antibody_training_esm/core/classifier.py
def set_params(self, **params: Any) -> "BinaryClassifier":
    """
    Set parameters for sklearn compatibility (required for cross_val_score)

    Args:
        **params: Parameters to set

    Returns:
        self

    Raises:
        ValueError: If ``model_type`` is given and is not supported

    Notes:
        Passing any embedding parameter (model_name, device, batch_size,
        revision, model_type) recreates the embedding extractor.
        Passing ``type`` recreates the classifier strategy and resets the
        fitted state. Any other parameter, including ``random_state``, is
        applied to the existing classifier strategy without touching the
        fitted state.
    """
    # Validate before mutating anything so a bad value cannot leave the
    # instance half-updated.
    if "model_type" in params and params["model_type"] not in SUPPORTED_MODEL_TYPES:
        raise ValueError(
            f"Unknown model_type: '{params['model_type']}'. "
            f"Supported types: {sorted(SUPPORTED_MODEL_TYPES)}"
        )

    # Update internal params dict
    self._params.update(params)

    # Parameters that configure the embedding extractor
    embedding_params = {
        "model_name",
        "device",
        "batch_size",
        "revision",
        "model_type",
    }

    if "random_state" in params:
        self.random_state = params["random_state"]

    if any(key in params for key in embedding_params):
        # Update instance attributes
        self.model_name = self._params.get("model_name", self.model_name)
        self.device = self._params.get("device", self.device)
        self.batch_size = self._params.get("batch_size", self.batch_size)
        self.revision = self._params.get("revision", self.revision)
        self._model_type = self._params.get("model_type", self._model_type)

        logger.info(
            "Recreating embedding extractor: model_name=%s, "
            "device=%s, batch_size=%s, model_type=%s",
            self.model_name,
            self.device,
            self.batch_size,
            self._model_type,
        )
        if self._model_type == "amplify":
            from antibody_training_esm.core.embeddings_amplify import (
                AMPLIFYEmbeddingExtractor,
            )

            self.embedding_extractor = AMPLIFYEmbeddingExtractor(
                self.model_name,
                self.device,
                self.batch_size,
                revision=self.revision,
            )
        elif self._model_type == "biophysical":
            from antibody_training_esm.core.embeddings_biophysical import (
                BiophysicalEmbeddingExtractor,
            )

            self.embedding_extractor = BiophysicalEmbeddingExtractor(
                self.model_name,
                self.device,
                self.batch_size,
                revision=self.revision,
            )
        else:
            self.embedding_extractor = ESMEmbeddingExtractor(
                self.model_name,
                self.device,
                self.batch_size,
                revision=self.revision,
            )

    if "type" in params:
        # The factory reads the merged params, so a new random_state or
        # hyperparameter passed in the same call is picked up here.
        logger.info("Recreating classifier: type=%s", params.get("type"))
        self.classifier = create_classifier(self._params)
        self.is_fitted = False  # New classifier is unfitted
        return self

    # Update existing classifier params (e.g., C, penalty, solver).
    # random_state is included so the strategy stays in sync with self.
    for key, value in params.items():
        if key in embedding_params:
            continue
        if not hasattr(self.classifier, key):
            if key != "random_state":
                logger.warning(
                    "Ignoring parameter %r: not an attribute of %s",
                    key,
                    type(self.classifier).__name__,
                )
            continue
        setattr(self.classifier, key, value)
        # Also update underlying sklearn classifier
        if hasattr(self.classifier, "classifier") and hasattr(
            self.classifier.classifier, key
        ):
            setattr(self.classifier.classifier, key, value)

    return self
fit(X, y)

Fit the classifier to the data

Parameters:

Name Type Description Default
X ndarray

Array of ESM-1V embeddings

required
y ndarray

Array of labels

required
Source code in src/antibody_training_esm/core/classifier.py
def fit(self, X: np.ndarray, y: np.ndarray) -> "BinaryClassifier":
    """
    Fit the classifier to the data

    Args:
        X: Array of embeddings
        y: Array of labels

    Returns:
        self, following the sklearn estimator convention so calls can be
        chained (``clf.fit(X, y).predict(X)``)
    """
    # Fit the classifier directly on embeddings (no scaling per Novo methodology)
    self.classifier.fit(X, y)
    self.is_fitted = True

    # sklearn 1.7+ requires classes_ attribute for cross_val_score compatibility
    self.classes_ = self.classifier.classes_

    logger.info("Classifier fitted on %d samples", len(X))
    return self
predict(X, threshold=0.5, assay_type=None)

Predict labels for the data with optional assay-specific thresholds

Parameters:

Name Type Description Default
X ndarray

Array of ESM-1V embeddings

required
threshold float

Decision threshold for classification (default: 0.5). Ignored if assay_type is specified.

0.5
assay_type str | None

Type of assay for dataset-specific thresholds. Options: 'ELISA' uses threshold=0.5 (for Jain, Boughter datasets); 'PSR' uses threshold=0.5495 (for Shehata, Harvey datasets); None uses the threshold parameter.

None

Returns:

Type Description
ndarray

Predicted labels

Raises:

Type Description
ValueError

If classifier is not fitted or assay_type is unknown

Notes

The model was trained on ELISA data (Boughter dataset). Different assay types measure different "spectrums" of non-specificity (Sakhnini et al. 2025, Section 2.7). Use assay_type='PSR' for PSR-based datasets to get calibrated predictions.

Source code in src/antibody_training_esm/core/classifier.py
def predict(
    self, X: np.ndarray, threshold: float = 0.5, assay_type: str | None = None
) -> np.ndarray:
    """
    Predict labels for the data with optional assay-specific thresholds

    Args:
        X: Array of ESM-1V embeddings
        threshold: Decision threshold for classification (default: 0.5)
                  Ignored if assay_type is specified
        assay_type: Type of assay for dataset-specific thresholds. Options:
                   - 'ELISA': Use threshold=0.5 (for Jain, Boughter datasets)
                   - 'PSR': Use threshold=0.5495 (for Shehata, Harvey datasets)
                   - None: Use the threshold parameter

    Returns:
        Predicted labels

    Raises:
        ValueError: If classifier is not fitted or assay_type is unknown

    Notes:
        The model was trained on ELISA data (Boughter dataset). Different assay types
        measure different "spectrums" of non-specificity (Sakhnini et al. 2025, Section 2.7).
        Use assay_type='PSR' for PSR-based datasets to get calibrated predictions.
    """
    if not self.is_fitted:
        raise ValueError("Classifier must be fitted before making predictions")

    # Determine which threshold to use
    if assay_type is not None:
        if assay_type not in self.ASSAY_THRESHOLDS:
            raise ValueError(
                f"Unknown assay_type '{assay_type}'. Must be one of: {list(self.ASSAY_THRESHOLDS.keys())}"
            )
        threshold = self.ASSAY_THRESHOLDS[assay_type]

    # Get probabilities and apply threshold
    probabilities = self.classifier.predict_proba(X)
    predictions: np.ndarray = (probabilities[:, 1] > threshold).astype(int)

    return predictions
predict_proba(X)

Predict class probabilities for the data

Parameters:

Name Type Description Default
X ndarray

Array of ESM-1V embeddings

required

Returns:

Type Description
ndarray

Predicted probabilities

Raises:

Type Description
ValueError

If classifier is not fitted

Source code in src/antibody_training_esm/core/classifier.py
def predict_proba(self, X: np.ndarray) -> np.ndarray:
    """
    Return the class probabilities computed by the classifier strategy

    Args:
        X: Array of embeddings

    Returns:
        Probabilities exactly as produced by the strategy's ``predict_proba``

    Raises:
        ValueError: If classifier is not fitted
    """
    if self.is_fitted:
        probabilities: np.ndarray = self.classifier.predict_proba(X)
        return probabilities

    raise ValueError("Classifier must be fitted before making predictions")
score(X, y)

Return the mean accuracy on the given test data and labels

Parameters:

Name Type Description Default
X ndarray

Array of ESM-1V embeddings

required
y ndarray

Array of true labels

required

Returns:

Type Description
float

Mean accuracy

Raises:

Type Description
ValueError

If classifier is not fitted

Source code in src/antibody_training_esm/core/classifier.py
def score(self, X: np.ndarray, y: np.ndarray) -> float:
    """
    Score the fitted classifier on test data by delegating to the strategy

    Args:
        X: Array of embeddings
        y: Array of true labels

    Returns:
        Value returned by the strategy's ``score`` (documented as mean accuracy)

    Raises:
        ValueError: If classifier is not fitted
    """
    if self.is_fitted:
        accuracy: float = self.classifier.score(X, y)
        return accuracy

    raise ValueError("Classifier must be fitted before scoring")

Functions