# Source code for incerto.utils.models

"""
Common model architectures for examples.

Provides standard architectures used across multiple examples to avoid duplication.
"""

import torch
import torch.nn as nn
import torch.nn.functional as F


class ConvNet(nn.Module):
    """
    Simple CNN for MNIST-sized inputs (28x28).

    Architecture:
        Conv(1→32) → ReLU → Conv(32→64) → ReLU → MaxPool → Dropout
        → FC(9216→128) → ReLU → Dropout → FC(128→num_classes)

    Args:
        num_classes: Number of output classes (default: 10)
        dropout_rate: Dropout probability (default: 0.2)
        input_channels: Number of input channels (default: 1 for grayscale)
    """

    def __init__(
        self,
        num_classes: int = 10,
        dropout_rate: float = 0.2,
        input_channels: int = 1,
    ):
        super().__init__()
        # Two 3x3 conv layers (stride 1, no padding): 28x28 → 26x26 → 24x24.
        self.conv1 = nn.Conv2d(input_channels, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        self.dropout1 = nn.Dropout(dropout_rate)
        self.dropout2 = nn.Dropout(dropout_rate)
        # After 2x2 max-pool the map is 64 x 12 x 12 = 9216 features.
        self.fc1 = nn.Linear(9216, 128)
        self.fc2 = nn.Linear(128, num_classes)

    def forward(self, x):
        """Compute class logits for a batch of images."""
        h = F.relu(self.conv1(x))
        h = F.relu(self.conv2(h))
        h = self.dropout1(F.max_pool2d(h, 2))
        h = torch.flatten(h, 1)
        h = self.dropout2(F.relu(self.fc1(h)))
        return self.fc2(h)
class BasicBlock(nn.Module):
    """Basic residual block for ResNet."""

    def __init__(self, in_channels: int, out_channels: int, stride: int = 1):
        super().__init__()
        # Main path: two 3x3 convs (bias folded into the batch norms).
        self.conv1 = nn.Conv2d(
            in_channels, out_channels, 3, stride=stride, padding=1, bias=False
        )
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, 3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)
        # Shortcut path: identity when shapes already match, otherwise a
        # 1x1 projection conv + batch norm to align channels/resolution.
        if stride != 1 or in_channels != out_channels:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, 1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels),
            )
        else:
            self.shortcut = nn.Sequential()

    def forward(self, x):
        """Apply conv-bn-relu-conv-bn, add the shortcut, then a final ReLU."""
        identity = self.shortcut(x)
        main = F.relu(self.bn1(self.conv1(x)))
        main = self.bn2(self.conv2(main))
        return F.relu(main + identity)
class ResNet18(nn.Module):
    """
    ResNet-18 for CIFAR-10 sized inputs (32x32).

    Simplified version adapted for small images.

    Args:
        num_classes: Number of output classes (default: 10)
        input_channels: Number of input channels (default: 3 for RGB)
    """

    def __init__(self, num_classes: int = 10, input_channels: int = 3):
        super().__init__()
        # Stem: single 3x3 conv (no 7x7/maxpool stem — inputs are small).
        self.conv1 = nn.Conv2d(input_channels, 64, 3, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        # Four stages of two BasicBlocks each; stages 2-4 halve resolution.
        self.layer1 = self._make_layer(64, 64, 2, stride=1)
        self.layer2 = self._make_layer(64, 128, 2, stride=2)
        self.layer3 = self._make_layer(128, 256, 2, stride=2)
        self.layer4 = self._make_layer(256, 512, 2, stride=2)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512, num_classes)

    def _make_layer(
        self, in_channels: int, out_channels: int, num_blocks: int, stride: int
    ):
        # Only the first block of a stage may downsample; the rest keep stride 1.
        strides = [stride] + [1] * (num_blocks - 1)
        blocks = []
        channels = in_channels
        for s in strides:
            blocks.append(BasicBlock(channels, out_channels, s))
            channels = out_channels
        return nn.Sequential(*blocks)

    def forward(self, x):
        """Compute class logits for a batch of images."""
        h = F.relu(self.bn1(self.conv1(x)))
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            h = stage(h)
        h = torch.flatten(self.avgpool(h), 1)
        return self.fc(h)
[docs] class MLP(nn.Module): """ Multi-layer perceptron for tabular data. Args: input_dim: Input feature dimension hidden_dims: List of hidden layer dimensions num_classes: Number of output classes dropout_rate: Dropout probability (default: 0.2) batch_norm: Whether to use batch normalization (default: True) """
[docs] def __init__( self, input_dim: int, hidden_dims: list[int] = [128, 64], num_classes: int = 2, dropout_rate: float = 0.2, batch_norm: bool = True, ): super().__init__() layers = [] prev_dim = input_dim for hidden_dim in hidden_dims: layers.append(nn.Linear(prev_dim, hidden_dim)) if batch_norm: layers.append(nn.BatchNorm1d(hidden_dim)) layers.append(nn.ReLU()) layers.append(nn.Dropout(dropout_rate)) prev_dim = hidden_dim layers.append(nn.Linear(prev_dim, num_classes)) self.model = nn.Sequential(*layers)
[docs] def forward(self, x): return self.model(x)
# Names exported by ``from incerto.utils.models import *``.
__all__ = ["ConvNet", "ResNet18", "MLP", "BasicBlock"]