# Source code for incerto.utils.training

"""
Training utilities for examples.

Common training loops and helpers to reduce boilerplate in examples.
"""

import torch
import torch.nn as nn
import numpy as np
import random
from typing import Optional
from tqdm import tqdm


def seed_everything(seed: int = 42) -> None:
    """
    Set random seeds for reproducibility.

    Seeds Python's ``random`` module, NumPy, and PyTorch (CPU and all
    CUDA devices), and forces cuDNN into deterministic mode.

    Args:
        seed: Random seed value
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # No-op when CUDA is unavailable, safe to call unconditionally.
    torch.cuda.manual_seed_all(seed)
    # Make CUDA operations deterministic; disabling benchmark mode stops
    # cuDNN from auto-tuning algorithms (trades speed for reproducibility).
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
def train_epoch(
    model: nn.Module,
    train_loader: torch.utils.data.DataLoader,
    criterion: nn.Module,
    optimizer: torch.optim.Optimizer,
    device: torch.device,
    desc: str = "Training",
) -> dict:
    """
    Train model for one epoch.

    Accuracy is computed as argmax over dim 1 of the model outputs, i.e.
    this assumes a multi-class classification head producing logits.

    Args:
        model: Model to train
        train_loader: Training data loader
        criterion: Loss function
        optimizer: Optimizer
        device: Device to train on
        desc: Description for progress bar

    Returns:
        Dictionary with 'loss' and 'accuracy' (accuracy in percent)
    """
    model.train()
    total_loss = 0.0
    correct = 0
    total = 0

    pbar = tqdm(train_loader, desc=desc, leave=False)
    for inputs, targets in pbar:
        inputs, targets = inputs.to(device), targets.to(device)

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        _, predicted = outputs.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()

        # Update progress bar with running batch loss and cumulative accuracy.
        pbar.set_postfix(
            {"loss": f"{loss.item():.4f}", "acc": f"{100.*correct/total:.2f}%"}
        )

    # NOTE(review): epoch loss averages per-batch losses; assumes the
    # criterion reduces with 'mean' per batch — confirm against usage.
    return {
        "loss": total_loss / len(train_loader),
        "accuracy": 100.0 * correct / total,
    }
def evaluate( model: nn.Module, data_loader: torch.utils.data.DataLoader, criterion: nn.Module, device: torch.device, desc: str = "Evaluating", ) -> dict: """ Evaluate model on dataset. Args: model: Model to evaluate data_loader: Data loader criterion: Loss function device: Device to evaluate on desc: Description for progress bar Returns: Dictionary with 'loss', 'accuracy', 'logits', 'targets' """ model.eval() total_loss = 0.0 correct = 0 total = 0 all_logits = [] all_targets = [] with torch.no_grad(): pbar = tqdm(data_loader, desc=desc, leave=False) for inputs, targets in pbar: inputs, targets = inputs.to(device), targets.to(device) outputs = model(inputs) loss = criterion(outputs, targets) total_loss += loss.item() _, predicted = outputs.max(1) total += targets.size(0) correct += predicted.eq(targets).sum().item() all_logits.append(outputs.cpu()) all_targets.append(targets.cpu()) pbar.set_postfix( {"loss": f"{loss.item():.4f}", "acc": f"{100.*correct/total:.2f}%"} ) all_logits = torch.cat(all_logits) all_targets = torch.cat(all_targets) return { "loss": total_loss / len(data_loader), "accuracy": 100.0 * correct / total, "logits": all_logits, "targets": all_targets, } class EarlyStopping: """ Early stopping to stop training when validation loss stops improving. Args: patience: Number of epochs to wait for improvement min_delta: Minimum change to qualify as improvement mode: 'min' for loss, 'max' for accuracy """ def __init__( self, patience: int = 10, min_delta: float = 0.0, mode: str = "min", ): self.patience = patience self.min_delta = min_delta self.mode = mode self.counter = 0 self.best_score = None self.early_stop = False def __call__(self, score: float) -> bool: """ Check if should stop training. 
Args: score: Current validation score Returns: True if should stop, False otherwise """ if self.best_score is None: self.best_score = score return False if self.mode == "min": improved = score < (self.best_score - self.min_delta) else: improved = score > (self.best_score + self.min_delta) if improved: self.best_score = score self.counter = 0 else: self.counter += 1 if self.counter >= self.patience: self.early_stop = True return self.early_stop __all__ = [ "seed_everything", "train_epoch", "evaluate", "EarlyStopping", ]