"""
Common model architectures for examples.
Provides standard architectures used across multiple examples to avoid duplication.
"""
import torch
import torch.nn as nn
import torch.nn.functional as F


class ConvNet(nn.Module):
"""
Simple CNN for MNIST-sized inputs (28x28).
Architecture:
Conv(1→32) → ReLU → Conv(32→64) → ReLU → MaxPool
→ Dropout → FC(9216→128) → ReLU → Dropout → FC(128→num_classes)
Args:
num_classes: Number of output classes (default: 10)
dropout_rate: Dropout probability (default: 0.2)
input_channels: Number of input channels (default: 1 for grayscale)
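
    Example (a minimal doctest-style sketch; shapes assume the default
    1x28x28 grayscale input):
        >>> model = ConvNet()
        >>> x = torch.randn(8, 1, 28, 28)
        >>> model(x).shape
        torch.Size([8, 10])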
"""
    def __init__(
        self,
        num_classes: int = 10,
        dropout_rate: float = 0.2,
        input_channels: int = 1,
    ):
        super().__init__()
        self.conv1 = nn.Conv2d(input_channels, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        self.dropout1 = nn.Dropout(dropout_rate)
        self.dropout2 = nn.Dropout(dropout_rate)
        # 9216 = 64 channels * 12 * 12 spatial (28x28 after two 3x3 convs and 2x2 pooling)
        self.fc1 = nn.Linear(9216, 128)
        self.fc2 = nn.Linear(128, num_classes)
    def forward(self, x):
        x = self.conv1(x)
        x = F.relu(x)
        x = self.conv2(x)
        x = F.relu(x)
        x = F.max_pool2d(x, 2)
        x = self.dropout1(x)
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.dropout2(x)
        x = self.fc2(x)
        return x


class BasicBlock(nn.Module):
"""Basic residual block for ResNet."""
    def __init__(self, in_channels: int, out_channels: int, stride: int = 1):
        super().__init__()
        self.conv1 = nn.Conv2d(
            in_channels, out_channels, 3, stride=stride, padding=1, bias=False
        )
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, 3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)

        # Identity shortcut by default; 1x1 projection when the shape changes
        self.shortcut = nn.Sequential()
        if stride != 1 or in_channels != out_channels:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, 1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels),
            )
    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += self.shortcut(x)
        out = F.relu(out)
        return out
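
# A quick shape sketch for BasicBlock (illustrative values, not from the
# original module): a stride-2 block halves the spatial resolution, and the
# 1x1 projection shortcut matches the channel count.
#
#     block = BasicBlock(64, 128, stride=2)
#     out = block(torch.randn(4, 64, 32, 32))  # -> shape (4, 128, 16, 16)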


class ResNet18(nn.Module):
"""
ResNet-18 for CIFAR-10 sized inputs (32x32).
Simplified version adapted for small images.
Args:
num_classes: Number of output classes (default: 10)
input_channels: Number of input channels (default: 3 for RGB)
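
    Example (a minimal doctest-style sketch; assumes the default
    3x32x32 CIFAR-style input):
        >>> model = ResNet18()
        >>> x = torch.randn(4, 3, 32, 32)
        >>> model(x).shape
        torch.Size([4, 10])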
"""
    def __init__(self, num_classes: int = 10, input_channels: int = 3):
        super().__init__()
        self.conv1 = nn.Conv2d(input_channels, 64, 3, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)

        # Residual blocks
        self.layer1 = self._make_layer(64, 64, 2, stride=1)
        self.layer2 = self._make_layer(64, 128, 2, stride=2)
        self.layer3 = self._make_layer(128, 256, 2, stride=2)
        self.layer4 = self._make_layer(256, 512, 2, stride=2)

        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512, num_classes)
    def _make_layer(
        self, in_channels: int, out_channels: int, num_blocks: int, stride: int
    ):
        layers = []
        # First block may have stride > 1
        layers.append(BasicBlock(in_channels, out_channels, stride))
        # Remaining blocks
        for _ in range(1, num_blocks):
            layers.append(BasicBlock(out_channels, out_channels, 1))
        return nn.Sequential(*layers)
    def forward(self, x):
        x = F.relu(self.bn1(self.conv1(x)))
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.fc(x)
        return x


class MLP(nn.Module):
"""
Multi-layer perceptron for tabular data.
Args:
input_dim: Input feature dimension
hidden_dims: List of hidden layer dimensions
num_classes: Number of output classes
dropout_rate: Dropout probability (default: 0.2)
batch_norm: Whether to use batch normalization (default: True)
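
    Example (a minimal doctest-style sketch; the 20-feature input is an
    arbitrary illustration):
        >>> model = MLP(input_dim=20)
        >>> x = torch.randn(16, 20)
        >>> model(x).shape
        torch.Size([16, 2])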
"""
    def __init__(
        self,
        input_dim: int,
        hidden_dims: list[int] | None = None,
        num_classes: int = 2,
        dropout_rate: float = 0.2,
        batch_norm: bool = True,
    ):
        super().__init__()
        # Avoid a mutable default argument; fall back to two hidden layers
        if hidden_dims is None:
            hidden_dims = [128, 64]

        layers = []
        prev_dim = input_dim
        for hidden_dim in hidden_dims:
            layers.append(nn.Linear(prev_dim, hidden_dim))
            if batch_norm:
                layers.append(nn.BatchNorm1d(hidden_dim))
            layers.append(nn.ReLU())
            layers.append(nn.Dropout(dropout_rate))
            prev_dim = hidden_dim
        # Final classification head
        layers.append(nn.Linear(prev_dim, num_classes))
        self.model = nn.Sequential(*layers)
    def forward(self, x):
        return self.model(x)


__all__ = ["ConvNet", "ResNet18", "MLP", "BasicBlock"]
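

if __name__ == "__main__":
    # Smoke test (a sketch, not part of the module's public surface): run
    # each architecture on a small dummy batch and print the output shapes.
    convnet = ConvNet()
    print(convnet(torch.randn(2, 1, 28, 28)).shape)  # torch.Size([2, 10])

    resnet = ResNet18()
    print(resnet(torch.randn(2, 3, 32, 32)).shape)   # torch.Size([2, 10])

    mlp = MLP(input_dim=20)
    print(mlp(torch.randn(2, 20)).shape)             # torch.Size([2, 2])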