Spec 055: Embedding NaN Detection and Validation
- Status: Implemented (PR #92, 2026-01-03)
- Priority: High
- Complexity: Low
- Related: PIPELINE-BRITTLENESS.md
SSOT (Implemented)
- Code: src/ai_psychiatrist/infrastructure/validation.py (validate_embedding(), validate_embedding_matrix())
- Exceptions: src/ai_psychiatrist/domain/exceptions.py (EmbeddingValidationError)
- Wire-up: src/ai_psychiatrist/services/reference_store.py, src/ai_psychiatrist/services/embedding.py, scripts/generate_embeddings.py
- Tests: tests/unit/infrastructure/test_validation.py, tests/unit/services/test_reference_store.py, tests/unit/services/test_embedding.py, tests/unit/scripts/test_generate_embeddings_fail_fast.py
Problem Statement
NaN (Not a Number) values in embedding vectors propagate silently through the pipeline:
- Source: Embedding backend returns NaN (rare but possible with malformed input)
- Propagation: L2 normalization with NaN → NaN persists
- Corruption: Cosine similarity with NaN → NaN similarity scores
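- Result: Reference ranking becomes meaningless. Every comparison involving NaN evaluates to False, so sorting by similarity no longer reflects similarity.
This is silent corruption. No stage raises an error, so the pipeline produces unpredictable results with no signal that anything went wrong.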
Previous Behavior (Fixed)
```python
# src/ai_psychiatrist/services/reference_store.py - L2 normalization
import numpy as np


def _l2_normalize(embedding: list[float]) -> list[float]:
    arr = np.array(embedding, dtype=np.float32)
    norm = float(np.linalg.norm(arr))
    if norm > 0:
        arr = arr / norm
    return arr.tolist()
```
If embedding contains NaN:
- np.linalg.norm(arr) returns NaN
- norm > 0 evaluates to False (every comparison with NaN is False), so the division is skipped and the NaNs remain in the embedding
- No error is raised; NaNs propagate into the reference matrix
```python
# src/ai_psychiatrist/services/embedding.py - similarity computation
similarities = matrix @ query_vec  # NaN propagates
similarities = (1.0 + similarities) / 2.0  # maps cosine [-1, 1] to [0, 1]; NaN stays NaN
```
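The failure can be reproduced in isolation. The sketch below copies the previous _l2_normalize logic and the similarity mapping into a standalone script. The two-row reference matrix and the query values are made-up toy data. Nothing raises, yet every score comes out NaN.

```python
import numpy as np


def old_l2_normalize(embedding: list[float]) -> list[float]:
    # Copy of the previous (unvalidated) normalization logic
    arr = np.array(embedding, dtype=np.float32)
    norm = float(np.linalg.norm(arr))
    if norm > 0:  # NaN > 0 is False, so the division is silently skipped
        arr = arr / norm
    return arr.tolist()


# Toy reference matrix: two unit-length rows
matrix = np.array([[1.0, 0.0], [0.0, 1.0]], dtype=np.float32)

# Toy query embedding in which the backend returned one NaN component
query = old_l2_normalize([0.6, float("nan")])
query_vec = np.array(query, dtype=np.float32)

similarities = matrix @ query_vec          # NaN * 0.0 is NaN, so every dot product is NaN
similarities = (1.0 + similarities) / 2.0  # the score mapping keeps NaN as NaN

print(query)         # the NaN survived normalization
print(similarities)  # no exception was raised, but no score is usable
```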
Implemented Solution
Validation for NaN, Inf, and all-zero vectors runs at every point where embeddings are generated or loaded, so corruption fails loudly at its point of origin.
Implementation
Core Validation Function
```python
# New: src/ai_psychiatrist/infrastructure/validation.py
import numpy as np

from ai_psychiatrist.domain.exceptions import EmbeddingValidationError


def validate_embedding(
    embedding: np.ndarray,
    context: str = "embedding",
    *,
    check_nan: bool = True,
    check_inf: bool = True,
    check_zero: bool = True,
) -> np.ndarray:
    """Validate embedding vector for common corruption patterns.

    Args:
        embedding: Vector to validate
        context: Description for error messages (e.g., "query embedding for participant 300")
        check_nan: Raise if NaN detected
        check_inf: Raise if Inf detected
        check_zero: Raise if all-zero vector detected (exact zero, as in validate_embedding_matrix)

    Returns:
        The validated embedding (unchanged)

    Raises:
        EmbeddingValidationError: If validation fails
    """
    if check_nan and np.isnan(embedding).any():
        nan_count = np.isnan(embedding).sum()
        nan_positions = np.where(np.isnan(embedding))[0][:5]  # First 5 positions
        raise EmbeddingValidationError(
            f"NaN detected in {context}: {nan_count} NaN values at positions {nan_positions.tolist()}"
        )
    if check_inf and np.isinf(embedding).any():
        inf_count = np.isinf(embedding).sum()
        raise EmbeddingValidationError(
            f"Inf detected in {context}: {inf_count} Inf values"
        )
    # Exact test: np.allclose(embedding, 0) (default atol=1e-08) would also reject
    # tiny but valid vectors such as [1e-10, 1e-10, 1e-10]
    if check_zero and not np.any(embedding):
        raise EmbeddingValidationError(
            f"All-zero vector in {context}: L2 norm is 0, cosine similarity undefined"
        )
    return embedding


def validate_embedding_matrix(
    matrix: np.ndarray,
    context: str = "embedding matrix",
) -> np.ndarray:
    """Validate entire embedding matrix.

    Args:
        matrix: 2D array of shape (n_samples, n_dims)
        context: Description for error messages

    Returns:
        The validated matrix (unchanged)

    Raises:
        EmbeddingValidationError: If validation fails
    """
    if matrix.ndim != 2:
        raise EmbeddingValidationError(
            f"Expected 2D matrix in {context}, got shape {matrix.shape}"
        )

    nan_mask = np.isnan(matrix)
    if nan_mask.any():
        nan_rows = np.where(nan_mask.any(axis=1))[0]
        raise EmbeddingValidationError(
            f"NaN detected in {context}: {len(nan_rows)} rows contain NaN "
            f"(first few: {nan_rows[:5].tolist()})"
        )

    inf_mask = np.isinf(matrix)
    if inf_mask.any():
        inf_rows = np.where(inf_mask.any(axis=1))[0]
        raise EmbeddingValidationError(
            f"Inf detected in {context}: {len(inf_rows)} rows contain Inf"
        )

    zero_rows = np.where(~matrix.any(axis=1))[0]
    if len(zero_rows) > 0:
        raise EmbeddingValidationError(
            f"All-zero rows in {context}: {len(zero_rows)} rows "
            f"(first few: {zero_rows[:5].tolist()})"
        )

    return matrix
```
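The vector and matrix validators agree on what counts as "all-zero" only if they use the same test. The standalone check below uses toy values and imports no project code. It shows that np.allclose(x, 0), with its default atol=1e-08, treats a tiny but non-zero vector as zero. The exact test that validate_embedding_matrix applies row-wise (~matrix.any(axis=1)) does not. Such a vector still has a positive L2 norm and normalizes cleanly, which is the behavior test_near_zero_passes expects.

```python
import numpy as np

tiny = np.array([1e-10, 1e-10, 1e-10], dtype=np.float32)

# Tolerance-based check: |x - 0| <= atol (1e-08 by default) holds for every element
allclose_says_zero = bool(np.allclose(tiny, 0))

# Exact check, as used row-wise by validate_embedding_matrix: any non-zero element counts
exact_says_zero = not tiny.any()

# The vector still has a usable direction: its norm is positive and normalization works
norm = float(np.linalg.norm(tiny))
unit = tiny / norm

print(allclose_says_zero, exact_says_zero, norm > 0)
```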
New Exception Type
```python
# domain/exceptions.py - add new exception
class EmbeddingValidationError(EmbeddingError):
    """Raised when embedding validation fails (NaN/Inf/zero)."""
```
Integration Points
1. Query Embedding Generation
```python
# src/ai_psychiatrist/services/embedding.py - EmbeddingService.embed_text()
async def embed_text(self, text: str) -> tuple[float, ...]:
    """Generate embedding for text."""
    response = await self._llm_client.embed(...)
    embedding = response.embedding
    if not embedding:
        return ()  # existing behavior for too-short text

    vector = np.array(embedding, dtype=np.float32)

    # NEW: Validate before returning
    from ai_psychiatrist.infrastructure.validation import validate_embedding

    validate_embedding(
        vector,
        context="query embedding (no text logged)",
    )
    return tuple(vector.tolist())
```
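The context string is the only diagnostic that reaches the exception message. It must identify the failing call without leaking transcript text. The Success Criteria call for participant/stage plus hashes or lengths, and one way to build such a string is sketched below. The helper name build_safe_context, the choice of SHA-256, and the 12-character digest prefix are assumptions for illustration, not the project's actual helper. For very short inputs even a hash can be guessed by enumeration, so the character count alone may be preferable there.

```python
import hashlib
from typing import Optional


def build_safe_context(stage: str, text: str, participant_id: Optional[int] = None) -> str:
    """Describe an embedding call without including its text (hypothetical helper)."""
    # Assumption: a 12-hex-char SHA-256 prefix is enough to correlate log lines
    digest = hashlib.sha256(text.encode("utf-8")).hexdigest()[:12]
    who = f"participant {participant_id}" if participant_id is not None else "unknown participant"
    # Only a hash prefix and the character count are included; the raw text never is
    return f"{stage} for {who} (sha256={digest}, chars={len(text)})"


ctx = build_safe_context("query embedding", "I have been sleeping badly.", participant_id=300)
print(ctx)
```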
2. Reference Store Loading
```python
# src/ai_psychiatrist/services/reference_store.py - ReferenceStore._l2_normalize()
def _l2_normalize(embedding: list[float]) -> list[float]:
    arr = np.array(embedding, dtype=np.float32)

    from ai_psychiatrist.infrastructure.validation import validate_embedding

    # Reject NaN/Inf/all-zero input before it can reach the reference matrix
    validate_embedding(arr, context="reference embedding pre-normalize")
    norm = float(np.linalg.norm(arr))
    if norm > 0:
        arr = arr / norm
    # Re-check after normalization: if the float32 norm overflowed to inf, arr / norm is all zeros
    validate_embedding(arr, context="reference embedding post-normalize", check_zero=True)
    return arr.tolist()
```
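The pre-normalize check alone is not sufficient, because a finite, non-zero float32 vector can still break normalization. In the toy example below, the squared components overflow float32 and the norm becomes inf. Dividing by inf then yields an all-zero vector, which is exactly what the post-normalize zero check catches. The values are contrived to force overflow. The checks are written inline rather than through validate_embedding so the snippet runs on its own.

```python
import numpy as np

arr = np.array([1e30, 1e30], dtype=np.float32)

# Pre-normalize state: finite and non-zero, so the first validation would pass
pre_ok = bool(np.isfinite(arr).all() and arr.any())

with np.errstate(over="ignore"):
    # Each square (1e60) exceeds the float32 maximum (~3.4e38), so the norm overflows to inf
    norm = float(np.linalg.norm(arr))

normalized = arr / norm if norm > 0 else arr  # inf > 0 is True, so the division runs

# Post-normalize state: every component is now 0.0, so cosine similarity is undefined
post_is_zero = not normalized.any()
print(pre_ok, norm, post_is_zero)
```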
3. Similarity Computation
```python
# src/ai_psychiatrist/services/embedding.py - EmbeddingService._compute_similarities()
from ai_psychiatrist.domain.exceptions import EmbeddingValidationError
from ai_psychiatrist.infrastructure.validation import validate_embedding

query_vec = np.array(query_embedding, dtype=np.float32)
validate_embedding(query_vec, context="query embedding pre-similarity")
similarities = matrix @ query_vec
# Output guard: catches overflow or a corrupted reference row that slipped past loading
if not np.isfinite(similarities).all():
    raise EmbeddingValidationError("Non-finite similarity scores (NaN/Inf)")
```
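The sketch below is a self-contained version of the similarity step. It combines the query check, the non-finite output guard, and the [0, 1] score mapping from the previous code. The function name compute_similarities and the use of ValueError in place of the project's EmbeddingValidationError are assumptions, made so the snippet runs without the codebase. The rows of matrix are assumed to be L2-normalized already, as ReferenceStore does. The query is normalized here for completeness, since the excerpt above does not show where query normalization happens.

```python
from typing import Optional

import numpy as np


def compute_similarities(matrix: np.ndarray, query_embedding: list[float]) -> np.ndarray:
    """Cosine scores in [0, 1] against L2-normalized rows (sketch; raises ValueError)."""
    query_vec = np.array(query_embedding, dtype=np.float32)
    if not np.isfinite(query_vec).all():
        raise ValueError("Non-finite values in query embedding")
    norm = float(np.linalg.norm(query_vec))
    if norm == 0:
        raise ValueError("All-zero query embedding")
    query_vec = query_vec / norm  # assumption: normalize the query here

    similarities = matrix @ query_vec
    if not np.isfinite(similarities).all():
        raise ValueError("Non-finite similarity scores (NaN/Inf)")
    return (1.0 + similarities) / 2.0  # map cosine [-1, 1] to [0, 1]


refs = np.array([[1.0, 0.0], [0.0, 1.0], [-1.0, 0.0]], dtype=np.float32)
scores = compute_similarities(refs, [3.0, 0.0])

nan_error: Optional[str] = None
try:
    compute_similarities(refs, [float("nan"), 0.0])
except ValueError as exc:
    nan_error = str(exc)
```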
4. Reference Artifact Generation (Recommended)
Also validate during scripts/generate_embeddings.py so bad artifacts fail fast (before writing .npz/.json):
- after each generate_embedding(...), validate the returned vector is finite and non-zero
- in --allow-partial mode, skip that chunk and record it in the .partial.json manifest
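A minimal sketch of that loop follows, with every name in it hypothetical. embed_fn stands in for the script's generate_embedding(...) call, whose real signature is not reproduced here. The manifest layout ("skipped" entries holding an index and a reason) is an assumed shape, not the actual .partial.json schema. Errors mention chunk indices only, never chunk text.

```python
from typing import Callable, Optional, Sequence

import numpy as np


def embed_chunks(
    chunks: Sequence[str],
    embed_fn: Callable[[str], list[float]],
    allow_partial: bool = False,
) -> tuple[np.ndarray, dict]:
    """Embed chunks, failing fast on NaN/Inf/all-zero vectors (hypothetical helper)."""
    rows: list[np.ndarray] = []
    manifest: dict = {"skipped": []}  # hypothetical manifest shape
    for index, chunk in enumerate(chunks):
        vec = np.asarray(embed_fn(chunk), dtype=np.float32)
        if not np.isfinite(vec).all():
            reason = "non-finite values"
        elif not vec.any():
            reason = "all-zero vector"
        else:
            rows.append(vec)
            continue
        if not allow_partial:
            # Fail before anything would be written to .npz/.json
            raise ValueError(f"chunk {index}: {reason}")
        manifest["skipped"].append({"index": index, "reason": reason})
    matrix = np.vstack(rows) if rows else np.empty((0, 0), dtype=np.float32)
    return matrix, manifest


# Toy embedder: looks up canned vectors instead of calling a backend
fake_vectors = {"ok": [0.1, 0.2], "bad": [float("nan"), 0.3], "zero": [0.0, 0.0]}
matrix, manifest = embed_chunks(["ok", "bad", "zero"], fake_vectors.__getitem__, allow_partial=True)

strict_error: Optional[str] = None
try:
    embed_chunks(["ok", "bad"], fake_vectors.__getitem__)
except ValueError as exc:
    strict_error = str(exc)
```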
Testing
```python
# tests/unit/infrastructure/test_validation.py
import numpy as np
import pytest

from ai_psychiatrist.infrastructure.validation import (
    validate_embedding,
    validate_embedding_matrix,
)
from ai_psychiatrist.domain.exceptions import EmbeddingValidationError


class TestValidateEmbedding:
    def test_valid_embedding_passes(self):
        emb = np.array([0.1, 0.2, 0.3, 0.4])
        result = validate_embedding(emb, "test")
        assert np.array_equal(result, emb)

    def test_nan_raises(self):
        emb = np.array([0.1, np.nan, 0.3])
        with pytest.raises(EmbeddingValidationError) as exc:
            validate_embedding(emb, "test embedding")
        assert "NaN" in str(exc.value)
        assert "test embedding" in str(exc.value)

    def test_inf_raises(self):
        emb = np.array([0.1, np.inf, 0.3])
        with pytest.raises(EmbeddingValidationError) as exc:
            validate_embedding(emb, "test")
        assert "Inf" in str(exc.value)

    def test_negative_inf_raises(self):
        emb = np.array([0.1, -np.inf, 0.3])
        with pytest.raises(EmbeddingValidationError) as exc:
            validate_embedding(emb, "test")
        assert "Inf" in str(exc.value)

    def test_zero_vector_raises(self):
        emb = np.array([0.0, 0.0, 0.0])
        with pytest.raises(EmbeddingValidationError) as exc:
            validate_embedding(emb, "test")
        assert "zero" in str(exc.value).lower()

    def test_near_zero_passes(self):
        emb = np.array([1e-10, 1e-10, 1e-10])
        result = validate_embedding(emb, "test")  # Should not raise
        assert result is not None

    def test_checks_can_be_disabled(self):
        emb = np.array([np.nan, np.inf, 0.0])
        result = validate_embedding(
            emb, "test",
            check_nan=False,
            check_inf=False,
            check_zero=False,
        )
        assert np.isnan(result[0])


class TestValidateEmbeddingMatrix:
    def test_valid_matrix_passes(self):
        matrix = np.array([[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]])
        result = validate_embedding_matrix(matrix, "test matrix")
        assert np.array_equal(result, matrix)

    def test_nan_row_identified(self):
        matrix = np.array([
            [0.1, 0.2],
            [np.nan, 0.4],  # Row 1 has NaN
            [0.5, 0.6],
        ])
        with pytest.raises(EmbeddingValidationError) as exc:
            validate_embedding_matrix(matrix, "test")
        assert "1" in str(exc.value)  # Row index mentioned

    def test_multiple_nan_rows_reported(self):
        matrix = np.array([
            [np.nan, 0.2],  # Row 0
            [0.3, 0.4],
            [0.5, np.nan],  # Row 2
        ])
        with pytest.raises(EmbeddingValidationError) as exc:
            validate_embedding_matrix(matrix, "test")
        assert "2 rows" in str(exc.value)

    def test_zero_row_detected(self):
        matrix = np.array([
            [0.1, 0.2],
            [0.0, 0.0],  # Zero row
            [0.5, 0.6],
        ])
        with pytest.raises(EmbeddingValidationError) as exc:
            validate_embedding_matrix(matrix, "test")
        assert "zero" in str(exc.value).lower()

    def test_1d_array_rejected(self):
        vector = np.array([0.1, 0.2, 0.3])
        with pytest.raises(EmbeddingValidationError) as exc:
            validate_embedding_matrix(vector, "test")
        assert "2D" in str(exc.value)
```
Performance Considerations
| Operation | Overhead | Notes |
|---|---|---|
| np.isnan(embedding).any() | ~1μs for 4096-dim | Negligible |
| np.isnan(matrix).any() | ~1ms for 10K×4096 | Done once at load time |
| Per-query validation | ~2μs | Negligible per participant |
Conclusion: Overhead is negligible. Always validate.
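The figures above depend on hardware and array size. A rough way to check them locally is sketched below. It times a combined check (np.isfinite covers NaN and Inf in one pass, followed by an exact all-zero row test) rather than validate_embedding_matrix itself, so treat the results as indicative. The helper names has_bad_values and time_check are hypothetical. The table's 10K×4096 case holds about 164 MB of float32 data, so the default here uses fewer rows to keep the run light.

```python
import timeit

import numpy as np


def has_bad_values(x: np.ndarray) -> bool:
    # One pass for NaN and +/-Inf, plus an exact all-zero test along the last axis
    return bool(not np.isfinite(x).all() or (~x.any(axis=-1)).any())


def time_check(n_rows: int, dims: int = 4096, repeats: int = 20) -> float:
    """Average seconds per has_bad_values call on random float32 data."""
    rng = np.random.default_rng(0)
    data = rng.random((n_rows, dims), dtype=np.float32)
    return timeit.timeit(lambda: has_bad_values(data), number=repeats) / repeats


query_seconds = time_check(1)
matrix_seconds = time_check(1_000)  # raise to 10_000 to match the table's matrix case
print(f"query: {query_seconds * 1e6:.1f} µs, matrix: {matrix_seconds * 1e3:.2f} ms")
```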
Failure Modes After Implementation
| Scenario | Before | After |
|---|---|---|
| NaN in query embedding | Silent corruption | EmbeddingValidationError raised |
| NaN in reference matrix | Silent corruption | Fails at load time |
| Zero vector after normalization | Undefined similarity | EmbeddingValidationError raised |
| Inf from numerical overflow | Silent corruption | EmbeddingValidationError raised |
Rollout Plan
- Phase 1: Implement validation functions and exception
- Phase 2: Add validation at query embedding generation
- Phase 3: Add validation at reference matrix loading
- Phase 4: Add validation at similarity computation output
All phases can be deployed together. The change only adds strictness, and valid embeddings pass through unchanged.
Success Criteria
- All NaN/Inf/zero embeddings detected at point of origin
- Clear error messages with privacy-safe context (participant/stage + hashes/lengths; no transcript text)
- Test coverage for all edge cases
- <1ms additional latency per participant
Future Enhancements
- Automatic retry: if an embedding fails validation, retry with cleaned input
- Metric tracking: Log validation failure rates for monitoring
- Partial matrix handling: Skip invalid rows instead of failing entire load (opt-in)
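For the opt-in partial-matrix idea, one possible shape is to compute a per-row validity mask. The helper would return the surviving rows together with the indices that were dropped, so callers can also drop the matching metadata. The function name filter_valid_rows and its return contract are hypothetical; nothing like it exists in the current code.

```python
import numpy as np


def filter_valid_rows(matrix: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
    """Return (valid_rows, dropped_indices) for a 2D embedding matrix (hypothetical helper)."""
    if matrix.ndim != 2:
        raise ValueError(f"Expected 2D matrix, got shape {matrix.shape}")
    finite = np.isfinite(matrix).all(axis=1)  # no NaN or +/-Inf anywhere in the row
    nonzero = matrix.any(axis=1)              # at least one non-zero element
    keep = finite & nonzero
    return matrix[keep], np.flatnonzero(~keep)


matrix = np.array([[0.1, 0.2], [np.nan, 0.4], [0.0, 0.0], [0.5, 0.6]], dtype=np.float32)
valid, dropped = filter_valid_rows(matrix)
```

Returning the dropped indices rather than silently shrinking the matrix keeps the opt-in behavior auditable: callers can log or count the skipped rows, in line with the metric-tracking item above.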