
Shehata

shehata

Shehata Dataset Loader

Loads the preprocessed Shehata antibody polyreactivity dataset (PSR assay).

IMPORTANT: This module is for LOADING preprocessed data, not for running the preprocessing pipeline. The preprocessing scripts that CREATE the data are in: preprocessing/shehata/step2_extract_fragments.py

Dataset characteristics:
- 398 human full antibodies (paired VH + VL) from healthy-donor B cell subsets
- Binary labels derived from PSR (polyspecificity reagent) scores
- B cell subset metadata (IgG memory, IgM memory, Naïve, LLPCs)
- 16 fragment types (full antibody)

Source:
- data/test/shehata/raw/shehata-mmc2.xlsx

Reference:
- Shehata et al. (2019), "Affinity Maturation Enhances Antibody Specificity but Compromises Conformational Stability", Supplementary Material mmc2.xlsx

Classes

ShehataDataset

Bases: AntibodyDataset

Loader for Shehata antibody dataset.

This class provides an interface to LOAD preprocessed Shehata dataset files. It does NOT run the preprocessing pipeline - use preprocessing/shehata/step2_extract_fragments.py for that.

The Shehata dataset contains human B cell-derived antibodies with PSR scores measuring polyreactivity across B cell subsets (IgG memory, IgM memory, Naïve, LLPCs).

Sakhnini et al. (2025) treat 7/398 antibodies (~1.76%) as non-specific in their benchmark (Fig. S14C). By default, this loader binarizes PSR scores at the corresponding 98.24th percentile: antibodies whose PSR score is strictly above that threshold are labeled non-specific (label 1).
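As a check on that arithmetic: 1 - 7/398 ≈ 0.98241, which rounds to the DEFAULT_PSR_PERCENTILE of 0.9824. The sketch below is a minimal, stdlib-only stand-in for psr_scores.quantile(percentile) followed by the strict psr > threshold comparison. It assumes pandas' default linear interpolation; the helper names and the synthetic scores are hypothetical, not real Shehata data. It shows that 398 distinct scores leave exactly 7 above the threshold, while a smaller panel (for example, after rows without numeric PSR scores are dropped) can yield a different count.

```python
import math


def linear_quantile(values: list[float], q: float) -> float:
    """Quantile with linear interpolation (pandas Series.quantile's default method)."""
    if not 0.0 <= q <= 1.0:
        raise ValueError("q must be a fraction in [0, 1], e.g. 0.9824, not 98.24")
    xs = sorted(values)
    h = (len(xs) - 1) * q  # fractional rank into the sorted scores
    lo = math.floor(h)
    hi = min(lo + 1, len(xs) - 1)
    return xs[lo] + (h - lo) * (xs[hi] - xs[lo])


def count_non_specific(values: list[float], q: float) -> int:
    """Count scores strictly above the quantile threshold (mirrors psr > threshold)."""
    threshold = linear_quantile(values, q)
    return sum(v > threshold for v in values)


# 7 of 398 antibodies are non-specific, so the cut-off quantile is 1 - 7/398.
target_q = 1 - 7 / 398
print(f"{target_q:.4f}")  # 0.9824

# Hypothetical, distinct PSR-like scores (not real Shehata data).
scores = [i / 397 for i in range(398)]
print(count_non_specific(scores, 0.9824))  # 7

# With fewer valid scores, the same quantile can select a different count.
print(count_non_specific(scores[:300], 0.9824))  # 6
```

Because the threshold is recomputed on whatever scores survive cleaning, the "7 non-specific" figure holds only when all 398 PSR values are numeric and distinct near the top of the distribution; ties at the cut-off would also change the count under a strict comparison.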

Source code in src/antibody_training_esm/datasets/shehata.py
class ShehataDataset(AntibodyDataset):
    """
    Loader for Shehata antibody dataset.

    This class provides an interface to LOAD preprocessed Shehata dataset files.
    It does NOT run the preprocessing pipeline - use preprocessing/shehata/step2_extract_fragments.py for that.

    The Shehata dataset contains human B cell-derived antibodies with PSR scores measuring
    polyreactivity across B cell subsets (IgG memory, IgM memory, Naïve, LLPCs).

    Sakhnini et al. (2025) treat 7/398 (1.76%) as non-specific in their benchmark
    (Fig. S14C); by default this loader uses the corresponding 98.24th percentile
    threshold when binarizing PSR scores.
    """

    # Default PSR threshold (98.24th percentile; 7/398 non-specific in Sakhnini et al. 2025 benchmark)
    DEFAULT_PSR_PERCENTILE = 0.9824

    def __init__(
        self, output_dir: Path | None = None, logger: logging.Logger | None = None
    ):
        """
        Initialize Shehata dataset loader.

        Args:
            output_dir: Directory containing preprocessed fragment files
            logger: Logger instance
        """
        super().__init__(
            dataset_name="shehata",
            output_dir=output_dir or SHEHATA_OUTPUT_DIR,
            logger=logger,
        )

    @classmethod
    def get_schema(cls) -> pa.DataFrameSchema:
        return get_shehata_schema()

    def get_fragment_types(self) -> list[str]:
        """
        Return full antibody fragment types.

        Shehata contains VH + VL sequences, so we generate all 16 fragment types.

        Returns:
            List of 16 full antibody fragment types
        """
        return self.FULL_ANTIBODY_FRAGMENTS

    def calculate_psr_threshold(
        self,
        psr_scores: pd.Series,
        percentile: float | None = None,
    ) -> float:
        """
        Calculate PSR score threshold for binary classification.

        Based on the Sakhnini et al. (2025) benchmark, which treats 7/398 antibodies
        (~1.76%) as non-specific; the threshold is therefore the 98.24th percentile.

        Args:
            psr_scores: Series of PSR scores (numeric)
            percentile: Quantile as a fraction in [0, 1] (default: DEFAULT_PSR_PERCENTILE,
                0.9824, i.e. 1 - 7/398)

        Returns:
            PSR threshold value
        """
        if percentile is None:
            percentile = self.DEFAULT_PSR_PERCENTILE

        threshold = psr_scores.quantile(percentile)

        self.logger.info("\nPSR Score Analysis:")
        self.logger.info(f"  Valid PSR scores: {psr_scores.notna().sum()}")
        self.logger.info(f"  Mean: {psr_scores.mean():.4f}")
        self.logger.info(f"  Median: {psr_scores.median():.4f}")
        self.logger.info(f"  75th percentile: {psr_scores.quantile(0.75):.4f}")
        self.logger.info(f"  95th percentile: {psr_scores.quantile(0.95):.4f}")
        self.logger.info(f"  Max: {psr_scores.max():.4f}")
        self.logger.info(f"\n  PSR = 0: {(psr_scores == 0).sum()} antibodies")
        self.logger.info(f"  PSR > 0: {(psr_scores > 0).sum()} antibodies")
        self.logger.info(
            "\n  Benchmark target (Sakhnini et al. 2025): 7/398 non-specific (~1.76%, 98.24th percentile)"
        )
        self.logger.info(f"  Calculated threshold: {threshold:.4f}")

        return threshold

    def load_data(
        self,
        excel_path: str | Path | None = None,
        psr_threshold: float | None = None,
        **_: Any,
    ) -> pd.DataFrame:
        """
        Load Shehata dataset from Excel file.

        Args:
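            excel_path: Path to shehata-mmc2.xlsx (None uses SHEHATA_EXCEL_PATH)
            psr_threshold: PSR score threshold for binary classification.
                          If None, calculates the 98.24th percentile automatically.
                          Scores strictly above the threshold are labeled 1.

        Returns:
            DataFrame with columns: id, VH_sequence, VL_sequence, label, psr_measurement,
            b_cell_subset, and sequence (a copy of VH_sequence for schema validation)

        Raises:
            FileNotFoundError: If Excel file not found
            ValueError: If the Excel file contains no rows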
        """
        # Default path
        if excel_path is None:
            excel_path = SHEHATA_EXCEL_PATH

        excel_file = Path(excel_path)
        if not excel_file.exists():
            raise FileNotFoundError(
                f"Shehata Excel file not found: {excel_file}\n"
                "Please ensure shehata-mmc2.xlsx is in data/test/shehata/raw/"
            )

        # Load Excel
        self.logger.info(f"Reading Excel file: {excel_file}")
        df = pd.read_excel(excel_file)

        # Validate dataset is not empty
        if len(df) == 0:
            raise ValueError(
                f"Loaded dataset is empty: {excel_file}\n"
                "The Excel file may be corrupted or truncated. "
                "Please check the file or re-run preprocessing."
            )

        self.logger.info(f"  Loaded {len(df)} rows, {len(df.columns)} columns")

        # Sanitize sequences (remove IMGT gap characters)
        self.logger.info("Sanitizing sequences (removing gaps)...")
        vh_original = df["VH Protein"].copy()
        vl_original = df["VL Protein"].copy()

        df["VH Protein"] = df["VH Protein"].apply(
            lambda x: self.sanitize_sequence(x) if pd.notna(x) else x
        )
        df["VL Protein"] = df["VL Protein"].apply(
            lambda x: self.sanitize_sequence(x) if pd.notna(x) else x
        )

        # Count gaps removed
        gaps_vh = sum(str(s).count("-") if pd.notna(s) else 0 for s in vh_original)
        gaps_vl = sum(str(s).count("-") if pd.notna(s) else 0 for s in vl_original)

        if gaps_vh > 0 or gaps_vl > 0:
            self.logger.info(f"  Removed {gaps_vh} gap characters from VH sequences")
            self.logger.info(f"  Removed {gaps_vl} gap characters from VL sequences")

        # Drop rows without sequences (Excel metadata/footnotes)
        before_drop = len(df)
        df = df.dropna(subset=["VH Protein", "VL Protein"], how="all")
        dropped = before_drop - len(df)
        if dropped:
            self.logger.info(
                f"  Dropped {dropped} rows without VH/VL sequences (metadata)"
            )

        # Convert PSR scores to numeric
        psr_numeric = pd.to_numeric(df["PSR Score"], errors="coerce")
        invalid_psr_mask = psr_numeric.isna()

        if invalid_psr_mask.any():
            # Cast to str so missing or numeric clone names cannot break the join
            dropped_ids = df.loc[invalid_psr_mask, "Clone name"].astype(str).tolist()
            self.logger.warning(
                f"  Dropping {invalid_psr_mask.sum()} antibodies without numeric PSR scores: "
                f"{', '.join(dropped_ids)}"
            )
            df = df.loc[~invalid_psr_mask].reset_index(drop=True)
            psr_numeric = psr_numeric.loc[~invalid_psr_mask].reset_index(drop=True)

        # Calculate PSR threshold if not provided
        if psr_threshold is None:
            psr_threshold = self.calculate_psr_threshold(psr_numeric)
        else:
            self.logger.info(f"Using provided PSR threshold: {psr_threshold}")

        # Create standardized DataFrame
        df_output = pd.DataFrame(
            {
                "id": df["Clone name"],
                "VH_sequence": df["VH Protein"],
                "VL_sequence": df["VL Protein"],
                "label": (psr_numeric > psr_threshold).astype(int),
                "psr_measurement": psr_numeric,  # Renamed from psr_score to match schema
                "b_cell_subset": df["B cell subset"],
            }
        )

        # Create 'sequence' column for schema validation (use VH)
        if "sequence" not in df_output.columns and "VH_sequence" in df_output.columns:
            df_output["sequence"] = df_output["VH_sequence"]

        # Validate with Pandera
        df_output = self.validate_dataframe(df_output)

        # Label distribution
        self.logger.info("\nLabel distribution:")
        label_counts = df_output["label"].value_counts().sort_index()
        for label, count in label_counts.items():
            label_name = "Specific" if label == 0 else "Non-specific"
            percentage = (count / len(df_output)) * 100
            self.logger.info(
                f"  {label_name} (label={label}): {count} ({percentage:.1f}%)"
            )

        # B cell subset distribution
        self.logger.info("\nB cell subset distribution:")
        subset_counts = df_output["b_cell_subset"].value_counts()
        for subset, count in subset_counts.items():
            self.logger.info(f"  {subset}: {count}")

        return df_output
Functions
get_fragment_types()

Return full antibody fragment types.

Shehata contains VH + VL sequences, so we generate all 16 fragment types.

Returns:

- list[str]: List of 16 full antibody fragment types

calculate_psr_threshold(psr_scores, percentile=None)

Calculate PSR score threshold for binary classification.

Based on the Sakhnini et al. (2025) benchmark, which treats 7/398 antibodies (~1.76%) as non-specific; the threshold is therefore the 98.24th percentile.

Parameters:

- psr_scores (Series, required): Series of PSR scores (numeric)
- percentile (float | None, default None): Quantile as a fraction in [0, 1]; None uses DEFAULT_PSR_PERCENTILE (0.9824, i.e. 1 - 7/398)

Returns:

- float: PSR threshold value

load_data(excel_path=None, psr_threshold=None, **_)

Load Shehata dataset from Excel file.

Parameters:

- excel_path (str | Path | None, default None): Path to shehata-mmc2.xlsx; None uses SHEHATA_EXCEL_PATH
- psr_threshold (float | None, default None): PSR score threshold for binary classification. If None, the 98.24th percentile is calculated automatically. Scores strictly above the threshold are labeled 1 (non-specific).

Returns:

- DataFrame: columns id, VH_sequence, VL_sequence, label, psr_measurement, b_cell_subset, plus sequence (a copy of VH_sequence used for schema validation)

Raises:

- FileNotFoundError: If the Excel file is not found
- ValueError: If the Excel file contains no rows


Functions

load_shehata_data(excel_path=None, psr_threshold=None)

Convenience function to load preprocessed Shehata dataset.

IMPORTANT: This loads PREPROCESSED data. To preprocess raw data, use: preprocessing/shehata/step2_extract_fragments.py

Parameters:

- excel_path (str | None, default None): Path to shehata-mmc2.xlsx
- psr_threshold (float | None, default None): PSR threshold for classification (None = auto-calculate)

Returns:

- DataFrame: DataFrame with preprocessed data

Example

>>> from antibody_training_esm.datasets.shehata import load_shehata_data
>>> df = load_shehata_data()
>>> print(f"Loaded {len(df)} sequences")

Source code in src/antibody_training_esm/datasets/shehata.py
def load_shehata_data(
    excel_path: str | None = None,
    psr_threshold: float | None = None,
) -> pd.DataFrame:
    """
    Convenience function to load preprocessed Shehata dataset.

    IMPORTANT: This loads PREPROCESSED data. To preprocess raw data, use:
    preprocessing/shehata/step2_extract_fragments.py

    Args:
        excel_path: Path to shehata-mmc2.xlsx
        psr_threshold: PSR threshold for classification (None = auto-calculate)

    Returns:
        DataFrame with preprocessed data

    Example:
        >>> from antibody_training_esm.datasets.shehata import load_shehata_data
        >>> df = load_shehata_data()
        >>> print(f"Loaded {len(df)} sequences")
    """
    dataset = ShehataDataset()
    return dataset.load_data(excel_path=excel_path, psr_threshold=psr_threshold)