Source code for incerto.llm.sequence

"""
Sequence-level uncertainty quantification for LLMs.

These methods aggregate token-level uncertainties to produce
a single uncertainty score for an entire generated sequence.
"""

from __future__ import annotations
import torch
import torch.nn.functional as F
from .token import TokenEntropy, SurprisalScore


class SequenceProbability:
    """
    Joint probability of the entire sequence.

    Computed as the product of individual token probabilities.
    Lower values indicate higher uncertainty.
    """

    @staticmethod
    def compute(
        logits: torch.Tensor,
        token_ids: torch.Tensor,
        dim: int = -1,
    ) -> torch.Tensor:
        """
        Compute sequence probability.

        Args:
            logits: Token logits of shape (batch, seq_len, vocab_size)
            token_ids: Generated token IDs of shape (batch, seq_len)
            dim: Dimension to compute softmax over (default: -1)

        Returns:
            Sequence probabilities of shape (batch,)
        """
        log_probs = F.log_softmax(logits, dim=dim)
        # Gather log probs for the actually generated tokens
        token_log_probs = torch.gather(
            log_probs, dim=-1, index=token_ids.unsqueeze(-1)
        ).squeeze(-1)
        # Summing log probs corresponds to multiplying probabilities
        seq_log_prob = token_log_probs.sum(dim=1)
        seq_prob = torch.exp(seq_log_prob)
        return seq_prob
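

# Usage sketch (illustrative only, not part of the library API): scoring a
# short generated continuation with SequenceProbability. The random tensors
# and the helper name below are placeholders standing in for real causal-LM
# logits and sampled token IDs.
def _example_sequence_probability() -> None:
    batch, seq_len, vocab_size = 2, 4, 50
    logits = torch.randn(batch, seq_len, vocab_size)
    token_ids = torch.randint(0, vocab_size, (batch, seq_len))
    seq_prob = SequenceProbability.compute(logits, token_ids)
    # One joint probability per sequence; values shrink quickly with length
    assert seq_prob.shape == (batch,)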


class AverageLogProb:
    """
    Mean log-probability across the sequence.

    Equivalently, the negative of the average negative log-likelihood
    per token. Higher (less negative) values indicate lower uncertainty.
    """

    @staticmethod
    def compute(
        logits: torch.Tensor,
        token_ids: torch.Tensor,
        mask: torch.Tensor | None = None,
        dim: int = -1,
    ) -> torch.Tensor:
        """
        Compute average log probability.

        Args:
            logits: Token logits of shape (batch, seq_len, vocab_size)
            token_ids: Generated token IDs of shape (batch, seq_len)
            mask: Optional mask for padding of shape (batch, seq_len)
            dim: Dimension to compute softmax over (default: -1)

        Returns:
            Average log probabilities of shape (batch,)
        """
        log_probs = F.log_softmax(logits, dim=dim)
        # Gather log probs for the actually generated tokens
        token_log_probs = torch.gather(
            log_probs, dim=-1, index=token_ids.unsqueeze(-1)
        ).squeeze(-1)
        if mask is not None:
            # Zero out padding positions and normalize by actual length
            token_log_probs = token_log_probs * mask
            seq_lengths = mask.sum(dim=1)
            avg_log_prob = token_log_probs.sum(dim=1) / seq_lengths
        else:
            avg_log_prob = token_log_probs.mean(dim=1)
        return avg_log_prob


class NormalizedSequenceProb:
    """
    Length-normalized sequence probability.

    Accounts for the fact that longer sequences naturally have lower
    probabilities. Common in beam search and generation.
    """

    @staticmethod
    def compute(
        logits: torch.Tensor,
        token_ids: torch.Tensor,
        length_penalty: float = 1.0,
        mask: torch.Tensor | None = None,
        dim: int = -1,
    ) -> torch.Tensor:
        """
        Compute length-normalized sequence probability.

        Args:
            logits: Token logits of shape (batch, seq_len, vocab_size)
            token_ids: Generated token IDs of shape (batch, seq_len)
            length_penalty: Length normalization factor (higher = more penalty)
            mask: Optional mask for padding of shape (batch, seq_len)
            dim: Dimension to compute softmax over (default: -1)

        Returns:
            Normalized probabilities of shape (batch,)
        """
        avg_log_prob = AverageLogProb.compute(logits, token_ids, mask, dim)
        if mask is not None:
            seq_lengths = mask.sum(dim=1)
        else:
            seq_lengths = torch.tensor(
                token_ids.size(1), dtype=torch.float, device=token_ids.device
            ).expand(token_ids.size(0))
        # Recover the total (summed) log-probability, then apply the
        # standard length normalization: score / length^penalty
        seq_log_prob = avg_log_prob * seq_lengths
        normalized_score = seq_log_prob / (seq_lengths**length_penalty)
        return torch.exp(normalized_score)
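

# Usage sketch (illustrative only, not part of the library API): length-aware
# scoring with a padding mask, assuming right-padded batches where the mask is
# 1.0 for real tokens and 0.0 for padding. Shapes and the helper name are
# placeholders.
def _example_length_normalized_scores() -> None:
    batch, seq_len, vocab_size = 2, 6, 50
    logits = torch.randn(batch, seq_len, vocab_size)
    token_ids = torch.randint(0, vocab_size, (batch, seq_len))
    mask = torch.tensor([[1.0] * 6, [1.0] * 4 + [0.0] * 2])
    avg_lp = AverageLogProb.compute(logits, token_ids, mask=mask)
    norm_p = NormalizedSequenceProb.compute(
        logits, token_ids, length_penalty=1.0, mask=mask
    )
    # With length_penalty=1.0 the normalized score reduces to exp(avg_lp)
    assert torch.allclose(norm_p, torch.exp(avg_lp))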


class SequenceEntropy:
    """
    Aggregated entropy over the sequence.

    Can use mean, sum, or max aggregation of token-level entropies.
    """

    @staticmethod
    def compute(
        logits: torch.Tensor,
        mask: torch.Tensor | None = None,
        aggregation: str = "mean",
        dim: int = -1,
    ) -> torch.Tensor:
        """
        Compute sequence entropy.

        Args:
            logits: Token logits of shape (batch, seq_len, vocab_size)
            mask: Optional mask for padding of shape (batch, seq_len)
            aggregation: How to aggregate ("mean", "sum", "max")
            dim: Dimension to compute entropy over (default: -1)

        Returns:
            Sequence entropy of shape (batch,)
        """
        # Compute token-level entropy
        token_entropy = TokenEntropy.compute(logits, dim=dim)
        if mask is not None:
            token_entropy = token_entropy * mask

        # Aggregate
        if aggregation == "mean":
            if mask is not None:
                seq_lengths = mask.sum(dim=1)
                result = token_entropy.sum(dim=1) / seq_lengths
            else:
                result = token_entropy.mean(dim=1)
        elif aggregation == "sum":
            result = token_entropy.sum(dim=1)
        elif aggregation == "max":
            if mask is not None:
                # Set masked positions to -inf before taking the max
                token_entropy = token_entropy.masked_fill(~mask.bool(), float("-inf"))
            result, _ = token_entropy.max(dim=1)
        else:
            raise ValueError(f"Unknown aggregation: {aggregation}")
        return result
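

# Usage sketch (illustrative only, not part of the library API): comparing the
# three aggregation modes on the same batch of logits. "max" highlights the
# single most uncertain generation step, while "mean" averages over steps.
# The random tensors and the helper name are placeholders.
def _example_sequence_entropy() -> None:
    batch, seq_len, vocab_size = 2, 5, 50
    logits = torch.randn(batch, seq_len, vocab_size)
    mean_h = SequenceEntropy.compute(logits, aggregation="mean")
    sum_h = SequenceEntropy.compute(logits, aggregation="sum")
    max_h = SequenceEntropy.compute(logits, aggregation="max")
    # Without a mask, the summed entropy is the mean scaled by sequence length
    assert torch.allclose(sum_h, mean_h * seq_len)
    # Entropies are non-negative, so the max never exceeds the sum
    assert torch.all(max_h <= sum_h)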


class SequencePerplexity:
    """
    Perplexity of the entire sequence.

    Exponential of the negative average log-probability.
    Standard metric for language model quality.
    """

    @staticmethod
    def compute(
        logits: torch.Tensor,
        token_ids: torch.Tensor,
        mask: torch.Tensor | None = None,
        dim: int = -1,
    ) -> torch.Tensor:
        """
        Compute sequence perplexity.

        Args:
            logits: Token logits of shape (batch, seq_len, vocab_size)
            token_ids: Generated token IDs of shape (batch, seq_len)
            mask: Optional mask for padding of shape (batch, seq_len)
            dim: Dimension to compute softmax over (default: -1)

        Returns:
            Perplexity values of shape (batch,)
        """
        avg_log_prob = AverageLogProb.compute(logits, token_ids, mask, dim)
        # Perplexity = exp(-avg_log_prob), i.e. the exponentiated
        # mean negative log-likelihood per token
        perplexity = torch.exp(-avg_log_prob)
        return perplexity


class VarianceOfTokenProbs:
    """
    Variance of token probabilities across the sequence.

    Measures how consistent the model's confidence is throughout
    the generation. Higher variance indicates varying uncertainty.
    """

    @staticmethod
    def compute(
        logits: torch.Tensor,
        mask: torch.Tensor | None = None,
        dim: int = -1,
    ) -> torch.Tensor:
        """
        Compute variance of max token probabilities.

        Args:
            logits: Token logits of shape (batch, seq_len, vocab_size)
            mask: Optional mask for padding of shape (batch, seq_len)
            dim: Dimension to compute softmax over (default: -1)

        Returns:
            Variance of shape (batch,)
        """
        probs = F.softmax(logits, dim=dim)
        max_probs, _ = probs.max(dim=dim)
        if mask is not None:
            # Masked mean and variance over the real (non-padding) positions
            max_probs = max_probs * mask
            seq_lengths = mask.sum(dim=1)
            mean = max_probs.sum(dim=1) / seq_lengths
            diff = (max_probs - mean.unsqueeze(1)) ** 2
            diff = diff * mask
            variance = diff.sum(dim=1) / seq_lengths
        else:
            variance = max_probs.var(dim=1)
        return variance
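

# Usage sketch (illustrative only, not part of the library API): perplexity
# and confidence-variance for the same generation. Random tensors stand in
# for real model outputs; with a real LM, lower perplexity means the model
# was more confident in the tokens it actually produced. The helper name is
# a placeholder.
def _example_perplexity_and_variance() -> None:
    batch, seq_len, vocab_size = 2, 5, 50
    logits = torch.randn(batch, seq_len, vocab_size)
    token_ids = torch.randint(0, vocab_size, (batch, seq_len))
    ppl = SequencePerplexity.compute(logits, token_ids)
    var = VarianceOfTokenProbs.compute(logits)
    # Perplexity is the exponentiated mean NLL, so it is always >= 1
    assert torch.all(ppl >= 1.0)
    assert ppl.shape == (batch,) and var.shape == (batch,)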