CLI prediction with Pydantic validation.
Source code in src/antibody_training_esm/cli/predict.py
def predict_sequence_cli(
    sequence: str, threshold: float, assay_type: AssayType | None, cfg: DictConfig
) -> None:
    """Run a single prediction from the command line and print the outcome.

    A ``Predictor`` is built from ``cfg`` (model name, classifier path,
    device and optional classifier config path). The inputs are then wrapped
    in a ``PredictionRequest`` (validated by Pydantic) and passed to
    ``Predictor.predict_single``.

    On success three lines are printed: the sequence (cut to
    ``SEQUENCE_PREVIEW_LENGTH`` characters plus ``...`` when longer), the
    prediction, and the probability formatted as a percentage with two
    decimals.

    Args:
        sequence: Sequence to classify.
        threshold: Threshold forwarded to the prediction request.
        assay_type: Optional assay type forwarded to the prediction request.
        cfg: Configuration providing ``model.name``, ``classifier.path`` and,
            optionally, ``classifier.config_path``, ``model.device`` and
            ``hardware.device``.

    Raises:
        SystemExit: With status 1 if the predictor cannot be constructed or
            if a ``ValidationError`` is raised while building or running the
            request. A message is printed before exiting in both cases.
    """
    classifier_config = getattr(cfg.classifier, "config_path", None)

    # An explicit (truthy) model.device wins; otherwise fall back to
    # hardware.device, tolerating a config without a `hardware` section.
    device = getattr(cfg.model, "device", None)
    if not device:
        hardware_cfg = getattr(cfg, "hardware", None)
        device = getattr(hardware_cfg, "device", None)

    # Building the predictor loads the model, so any failure here is reported
    # as a load error and ends the process.
    try:
        predictor = Predictor(
            model_name=cfg.model.name,
            classifier_path=cfg.classifier.path,
            device=device,
            config_path=classifier_config,
        )
    except Exception as load_error:
        print(f"Error loading model: {load_error}")
        sys.exit(1)

    try:
        result = predictor.predict_single(
            PredictionRequest(
                sequence=sequence,
                threshold=threshold,
                assay_type=assay_type,
            )
        )

        # Show only a preview of long sequences.
        if len(result.sequence) > SEQUENCE_PREVIEW_LENGTH:
            print(f"Sequence: {result.sequence[:SEQUENCE_PREVIEW_LENGTH]}...")
        else:
            print(f"Sequence: {result.sequence}")
        print(f"Prediction: {result.prediction}")
        print(f"Probability: {result.probability:.2%}")
    except ValidationError as validation_error:
        print("❌ Validation Error:")
        for issue in validation_error.errors():
            # `loc` is a tuple such as ('sequence',); an empty tuple means the
            # error is not tied to a specific field.
            location = issue["loc"]
            field = location[0] if location else "root"
            print(f" - {field}: {issue['msg']}")
        sys.exit(1)
|