
Embeddings

embeddings

ESM Embedding Module

Module for extracting ESM-1V protein sequence embeddings. Handles batch processing, GPU memory management, and input validation.

Classes

ESMEmbeddingExtractor

Extract ESM-1V embeddings for protein sequences with proper batching and GPU management
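A minimal construction sketch; the import path is inferred from the source path shown below, the model name is the example from the constructor docstring, and the device choice is illustrative:

from antibody_training_esm.core.embeddings import ESMEmbeddingExtractor

# Load the pinned ESM-1V checkpoint; batch_size and max_length fall back to
# the module defaults when omitted
extractor = ESMEmbeddingExtractor(
    model_name="facebook/esm1v_t33_650M_UR90S_1",
    device="cpu",  # or "cuda" / "mps" if available
    revision="main",
)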

Source code in src/antibody_training_esm/core/embeddings.py
class ESMEmbeddingExtractor:
    """Extract ESM-1V embeddings for protein sequences with proper batching and GPU management"""

    def __init__(
        self,
        model_name: str,
        device: str,
        batch_size: int = DEFAULT_BATCH_SIZE,
        max_length: int = DEFAULT_MAX_SEQ_LENGTH,
        revision: str = "main",
    ):
        """
        Initialize ESM embedding extractor

        Args:
            model_name: HuggingFace model identifier (e.g., 'facebook/esm1v_t33_650M_UR90S_1')
            device: Device to run model on ('cpu', 'cuda', or 'mps')
            batch_size: Number of sequences to process per batch
            max_length: Maximum sequence length for tokenizer truncation/padding
            revision: HuggingFace model revision (commit SHA or branch name) for reproducibility
        """
        self.model_name = model_name
        self.device = device
        self.batch_size = batch_size
        self.max_length = max_length
        self.revision = revision

        # Load model with output_hidden_states enabled + pinned revision for reproducibility
        self.model = AutoModel.from_pretrained(
            model_name,
            output_hidden_states=True,
            revision=revision,  # nosec B615 - Pinned to specific version for scientific reproducibility
        )
        self.model.to(device)
        self.model.eval()
        self.tokenizer = AutoTokenizer.from_pretrained(
            model_name,
            revision=revision,  # nosec B615 - Pinned to specific version for scientific reproducibility
        )  # type: ignore[no-untyped-call]  # HuggingFace transformers lacks type stubs
        logger.info(
            f"ESM model {model_name} (revision={revision}) loaded on {device} "
            f"with batch_size={batch_size} and max_length={max_length}"
        )

    def embed_sequence(self, sequence: str) -> np.ndarray:
        """
        Extract ESM-1V embedding for a single protein sequence

        Args:
            sequence: Amino acid sequence string

        Returns:
            Embedding vector as numpy array

        Raises:
            ValueError: If sequence contains invalid amino acids or is too short
        """
        try:
            # Validate sequence (20 standard amino acids + X for unknown/ambiguous)
            # X is supported by ESM tokenizer for ambiguous residues
            valid_aas = set("ACDEFGHIKLMNPQRSTVWYX")
            sequence = sequence.upper().strip()

            if not all(aa in valid_aas for aa in sequence):
                raise ValueError("Invalid amino acid characters in sequence")

            if len(sequence) < 1:
                raise ValueError("Sequence too short")

            # Tokenize the sequence
            inputs = self.tokenizer(
                sequence,
                return_tensors="pt",
                truncation=True,
                max_length=self.max_length,
            )
            inputs = {k: v.to(self.device) for k, v in inputs.items()}

            # Get embeddings
            with torch.no_grad():
                outputs = self.model(**inputs, output_hidden_states=True)
                embeddings = outputs.hidden_states[-1]  # (batch, seq_len, hidden_dim)

                # Use attention mask to properly exclude padding and special tokens
                attention_mask = inputs["attention_mask"].unsqueeze(
                    -1
                )  # (batch, seq_len, 1)

                # Mask out special tokens (first and last)
                attention_mask[:, 0, :] = 0  # CLS token
                attention_mask[:, -1, :] = 0  # EOS token

                # Masked mean pooling
                masked_embeddings = embeddings * attention_mask
                sum_embeddings = masked_embeddings.sum(dim=1)  # Sum over sequence
                sum_mask = attention_mask.sum(dim=1)  # Count valid tokens

                # Prevent division by zero (NaN embeddings)
                if sum_mask.item() == 0:
                    raise ValueError(
                        f"Attention mask is all zeros for sequence (length: {len(sequence)}). "
                        f"Sequence preview: '{sequence[:SEQUENCE_PREVIEW_LENGTH]}...'. "
                        "This typically indicates an empty or invalid sequence after masking."
                    )

                mean_embeddings = sum_embeddings / sum_mask  # Average

                result: np.ndarray = mean_embeddings.squeeze(0).cpu().numpy()
                return result

        except Exception as e:
            # Add sequence context to error message (truncate for readability)
            seq_preview = (
                sequence[:SEQUENCE_PREVIEW_LENGTH] + "..."
                if len(sequence) > SEQUENCE_PREVIEW_LENGTH
                else sequence
            )
            logger.error(
                f"Error getting embeddings for sequence (length={len(sequence)}): {seq_preview}"
            )
            raise RuntimeError(
                f"Failed to extract embedding for sequence of length {len(sequence)}: {seq_preview}"
            ) from e

    def extract_batch_embeddings(self, sequences: list[str]) -> np.ndarray:
        """
        Extract embeddings for multiple sequences using efficient batching

        Args:
            sequences: List of amino acid sequence strings

        Returns:
            Array of embeddings with shape (n_sequences, embedding_dim)
        """
        embeddings_list = []

        logger.info(
            f"Extracting embeddings for {len(sequences)} sequences with batch_size={self.batch_size}..."
        )

        # Process sequences in batches
        num_batches = (len(sequences) + self.batch_size - 1) // self.batch_size

        for batch_idx in tqdm(range(num_batches), desc="Processing batches"):
            start_idx = batch_idx * self.batch_size
            end_idx = min(start_idx + self.batch_size, len(sequences))
            batch_sequences = sequences[start_idx:end_idx]

            try:
                # Validate and clean sequences
                valid_aas = set("ACDEFGHIKLMNPQRSTVWYX")
                cleaned_sequences: list[str] = []
                invalid_sequences: list[
                    tuple[int, str, str]
                ] = []  # (index, sequence, reason)

                for seq_idx, seq in enumerate(batch_sequences):
                    seq = seq.upper().strip()
                    global_idx = start_idx + seq_idx

                    # Check for empty/short sequences
                    if len(seq) < 1:
                        invalid_sequences.append(
                            (global_idx, seq, "empty or too short")
                        )
                        continue

                    # Check for invalid amino acids
                    invalid_chars = [aa for aa in seq if aa not in valid_aas]
                    if invalid_chars:
                        reason = f"invalid characters: {set(invalid_chars)}"
                        invalid_sequences.append(
                            (global_idx, seq[:SEQUENCE_PREVIEW_LENGTH], reason)
                        )
                        continue

                    cleaned_sequences.append(seq)

                # If any sequences are invalid, fail immediately
                if invalid_sequences:
                    error_details = "\n".join(
                        f"  Index {idx}: '{seq}...' ({reason})"
                        for idx, seq, reason in invalid_sequences[:ERROR_PREVIEW_LIMIT]
                    )
                    total_invalid = len(invalid_sequences)
                    raise ValueError(
                        f"Found {total_invalid} invalid sequence(s) in batch {batch_idx}:\n{error_details}"
                        + (
                            f"\n  ... and {total_invalid - ERROR_PREVIEW_LIMIT} more"
                            if total_invalid > ERROR_PREVIEW_LIMIT
                            else ""
                        )
                    )

                # Tokenize the batch with padding
                inputs = self.tokenizer(
                    cleaned_sequences,
                    return_tensors="pt",
                    padding=True,
                    truncation=True,
                    max_length=self.max_length,
                )
                inputs = {k: v.to(self.device) for k, v in inputs.items()}

                # Get embeddings for the batch
                with torch.no_grad():
                    outputs = self.model(**inputs, output_hidden_states=True)
                    embeddings = outputs.hidden_states[
                        -1
                    ]  # (batch, seq_len, hidden_dim)

                    # Use attention mask to properly exclude padding and special tokens
                    attention_mask = inputs["attention_mask"].unsqueeze(
                        -1
                    )  # (batch, seq_len, 1)

                    # Mask out special tokens: CLS at position 0, and the last
                    # column (EOS for the longest sequences in the batch; for
                    # shorter sequences this is an already-zero padding position)
                    attention_mask[:, 0, :] = 0  # CLS token
                    attention_mask[:, -1, :] = 0  # last column

                    # Masked mean pooling
                    masked_embeddings = embeddings * attention_mask
                    sum_embeddings = masked_embeddings.sum(dim=1)  # Sum over sequence
                    sum_mask = attention_mask.sum(dim=1)  # Count valid tokens

                    # Prevent division by zero (NaN embeddings) by clamping the
                    # divisor to a small positive value before dividing
                    sum_mask_safe = sum_mask.clamp(min=1e-9)
                    mean_embeddings = sum_embeddings / sum_mask_safe  # Average

                    # Check if any sequences had zero mask (would produce near-zero or invalid embeddings)
                    zero_mask_indices = (
                        (sum_mask == 0).any(dim=1).nonzero(as_tuple=True)[0]
                    )
                    if len(zero_mask_indices) > 0:
                        bad_seqs = [
                            cleaned_sequences[i.item()][:SEQUENCE_PREVIEW_LENGTH]
                            for i in zero_mask_indices[:3]
                        ]
                        raise ValueError(
                            f"Found {len(zero_mask_indices)} sequence(s) with zero attention mask in batch {batch_idx}. "
                            f"Sample sequences: {bad_seqs}. This indicates empty/invalid sequences after masking."
                        )

                    # Convert to numpy and add to list
                    batch_embeddings = mean_embeddings.cpu().numpy()
                    for emb in batch_embeddings:
                        embeddings_list.append(emb)

                # Clear GPU cache periodically to prevent OOM
                if (batch_idx + 1) % GPU_CACHE_CLEAR_INTERVAL == 0:
                    self._clear_gpu_cache()

            except Exception as e:
                logger.error(
                    f"CRITICAL: Failed to process batch {batch_idx} (sequences {start_idx}-{end_idx}): {e}"
                )
                logger.error(
                    f"First sequence in failed batch: {batch_sequences[0][:100]}..."
                )
                raise RuntimeError(
                    f"Batch processing failed at batch {batch_idx}. Cannot continue with corrupted embeddings. "
                    f"Original error: {e}"
                ) from e

        # Final cache clear
        self._clear_gpu_cache()
        return np.array(embeddings_list)

    def _clear_gpu_cache(self) -> None:
        """Clear GPU cache for CUDA or MPS devices to prevent memory leaks"""
        if str(self.device).startswith("cuda"):
            torch.cuda.empty_cache()
        elif str(self.device).startswith("mps"):
            torch.mps.empty_cache()
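The pooling step above (a masked mean over the last hidden states with the first and last positions zeroed out) can be reproduced in isolation. A minimal sketch with dummy tensors, shapes chosen purely for illustration:

import torch

# Dummy "hidden states" for a batch of 1 sequence of 6 tokens, hidden size 4
embeddings = torch.randn(1, 6, 4)
attention_mask = torch.ones(1, 6, 1)

# Zero out the special-token positions (CLS first, EOS last), as in embed_sequence
attention_mask[:, 0, :] = 0
attention_mask[:, -1, :] = 0

# Masked mean pooling: sum of unmasked token embeddings divided by their count
masked = embeddings * attention_mask
mean_embedding = masked.sum(dim=1) / attention_mask.sum(dim=1)
print(mean_embedding.shape)  # torch.Size([1, 4])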
Functions

embed_sequence(sequence)

Extract ESM-1V embedding for a single protein sequence

Parameters:

    Name       Type   Description                  Default
    sequence   str    Amino acid sequence string   required

Returns:

    Type      Description
    ndarray   Embedding vector as numpy array

Raises:

    Type         Description
    ValueError   If sequence contains invalid amino acids or is too short
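A short call sketch, assuming the extractor instance from the construction example above:

# One pooled vector per sequence; for the 650M ESM-1V checkpoints the hidden
# size is 1280, so this prints (1280,)
embedding = extractor.embed_sequence("EVQLVESGGGLVQPGGSLRLSCAAS")
print(embedding.shape)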

extract_batch_embeddings(sequences)

Extract embeddings for multiple sequences using efficient batching

Parameters:

    Name        Type        Description                           Default
    sequences   list[str]   List of amino acid sequence strings   required

Returns:

    Type      Description
    ndarray   Array of embeddings with shape (n_sequences, embedding_dim)
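A short sketch of batch extraction and of the fail-fast validation, again assuming the extractor instance from the construction example:

sequences = [
    "EVQLVESGGGLVQPGGSLRLSCAAS",
    "DIQMTQSPSSLSASVGDRVTITC",
    "QVQLQESGPGLVKPSETLSLTCTVS",
]
embeddings = extractor.extract_batch_embeddings(sequences)
print(embeddings.shape)  # (3, embedding_dim)

# A batch containing a non-amino-acid character ('B' here) fails fast: the
# ValueError raised during validation is re-raised as a RuntimeError
try:
    extractor.extract_batch_embeddings(["EVQLVESGG", "EVQLBESGG"])
except RuntimeError as err:
    print(err)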
