Source code for eyefeatures.deep.models

from collections.abc import Callable

import pytorch_lightning as pl
import torch
import torchmetrics
from torch import nn
from torch.nn import functional as F
from torch_geometric.nn import GCNConv, GINConv, global_mean_pool
from tqdm import tqdm


[docs] class VGGBlock(nn.Module): """A VGG-style convolutional block consisting of a convolution layer, batch normalization, and ReLU activation. Args: in_channels: (int) Number of input channels. out_channels: (int) Number of output channels. kernel_size: (int, optional) Size of the convolution kernel. Default is 3. padding: (int, optional) Padding for the convolution layer. Default is 1. stride: (int, optional) Stride for the convolution layer. Default is 1. Returns: Output tensor after applying the convolution, batch normalization, and ReLU activation. """ def __init__( self, in_channels: int, out_channels: int, kernel_size: int = 3, padding: int = 1, stride: int = 1, ): super().__init__() self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding) self.bn = nn.BatchNorm2d(out_channels) self.relu = nn.ReLU(inplace=True) def forward(self, images): x = self.conv(images) x = self.bn(x) x = self.relu(x) return x
[docs] class ResnetBlock(nn.Module): """A ResNet-style residual block consisting of two convolution layers with skip connections. Args: in_channels (int): Number of input channels. out_channels (int): Number of output channels. kernel_size (int, optional): Size of the convolution kernel. Default is 3. padding (int, optional): Padding for the convolution layers. Default is 1. stride (int, optional): Stride for the convolution layers. Default is 1. Returns: Output tensor after applying residual connection and ReLU activation. """ def __init__( self, in_channels: int, out_channels: int, kernel_size: int = 3, padding: int = 1, stride: int = 1, ): super().__init__() self.conv1 = nn.Conv2d( in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding, ) self.bn1 = nn.BatchNorm2d(out_channels) self.relu = nn.ReLU(inplace=True) self.conv2 = nn.Conv2d( out_channels, out_channels, kernel_size=kernel_size, padding=padding ) self.bn2 = nn.BatchNorm2d(out_channels) # Shortcut when channels or spatial size change (standard ResNet) if in_channels != out_channels or stride != 1: self.shortcut = nn.Conv2d(in_channels, out_channels, 1, stride=stride) else: self.shortcut = nn.Identity() def forward(self, images): identity = self.shortcut(images) out = self.conv1(images) out = self.bn1(out) out = self.relu(out) out = self.conv2(out) out = self.bn2(out) out += identity out = self.relu(out) return out
[docs] class InceptionBlock(nn.Module): """An Inception-style block consisting of multiple convolution branches operating on different scales. Args: in_channels: (int) Number of input channels. ch1x1: (int) Number of output channels for the 1x1 convolution branch. ch3x3_reduce: (int) Number of output channels for the 1x1 convolution before the 3x3 convolution. ch3x3: (int) Number of output channels for the 3x3 convolution branch. ch5x5_reduce: (int) Number of output channels for the 1x1 convolution before the 5x5 convolution. ch5x5: (int) Number of output channels for the 5x5 convolution branch. pool_proj: (int) Number of output channels for the 1x1 convolution after the max pooling branch. Returns: Concatenated output tensor from all branches. """ def __init__( self, in_channels: int, ch1x1: int, ch3x3_reduce: int, ch3x3: int, ch5x5_reduce: int, ch5x5: int, pool_proj: int, ): super().__init__() self.branch1x1 = nn.Conv2d(in_channels, ch1x1, kernel_size=1) self.branch3x3 = nn.Sequential( nn.Conv2d(in_channels, ch3x3_reduce, kernel_size=1), nn.Conv2d(ch3x3_reduce, ch3x3, kernel_size=3, padding=1), ) self.branch5x5 = nn.Sequential( nn.Conv2d(in_channels, ch5x5_reduce, kernel_size=1), nn.Conv2d(ch5x5_reduce, ch5x5, kernel_size=5, padding=2), ) self.branch_pool = nn.Sequential( nn.MaxPool2d(kernel_size=3, stride=1, padding=1), nn.Conv2d(in_channels, pool_proj, kernel_size=1), ) def forward(self, images): branch1x1 = self.branch1x1(images) branch3x3 = self.branch3x3(images) branch5x5 = self.branch5x5(images) branch_pool = self.branch_pool(images) outputs = [branch1x1, branch3x3, branch5x5, branch_pool] return torch.cat(outputs, 1)
[docs] class DSCBlock(nn.Module): """Depthwise separable convolution, which consists of a depthwise convolution followed by a pointwise convolution. Args: in_channels: (int) Number of input channels. out_channels: (int) Number of output channels. kernel_size: (int, optional) Size of the convolution kernel. Default is 3. stride: (int, optional) Stride for the convolution layers. Default is 1. padding: (int, optional) Padding for the convolution layers. Default is 1. Returns: Output tensor after applying depthwise and pointwise convolutions, batch normalization, and ReLU activation. """ def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1): super().__init__() self.depthwise = nn.Conv2d( in_channels, in_channels, kernel_size=kernel_size, stride=stride, padding=padding, groups=in_channels, bias=False, ) self.pointwise = nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False) self.bn = nn.BatchNorm2d(out_channels) self.relu = nn.ReLU(inplace=True) def forward(self, images): x = self.depthwise(images) x = self.pointwise(x) x = self.bn(x) x = self.relu(x) return x
_blocks = { "VGG_block": VGGBlock, "Resnet_block": ResnetBlock, "Inception_block": InceptionBlock, "DSC_block": DSCBlock, }
[docs] def create_simple_CNN( config: dict[int, dict[str, str | dict]], in_channels: int, shape: tuple[int, int] = None, ): """Creates a simple CNN based on the provided configuration. Args: config: (Dict) Configuration dictionary where each key represents a layer/block and its corresponding parameters. Supported types: - Block types: 'VGG_block', 'Resnet_block', 'DSC_block', 'Inception_block' - Pooling: 'MaxPool2d' (with params: kernel_size, stride, padding) in_channels: (int) Number of input channels for the first layer. shape: (Tuple[int, int], optional) Input shape for the CNN. If provided, checks that the final output shape is valid. Returns: (nn.Sequential, Tuple[int, int]) or nn.Sequential cnn: (nn.Sequential) Sequential CNN model. shape: (Tuple[int, int], optional) Final output shape, if provided. """ modules = [] for idx, block_config in tqdm(config.items()): block_type = block_config["type"] params = block_config.get("params", {}) # Handle MaxPool2d separately (doesn't change channel count) if block_type == "MaxPool2d": kernel_size = params.get("kernel_size", 2) stride = params.get("stride", kernel_size) padding = params.get("padding", 0) module = nn.MaxPool2d( kernel_size=kernel_size, stride=stride, padding=padding ) if shape is not None: shape = ( (shape[0] + 2 * padding - kernel_size) // stride + 1, (shape[1] + 2 * padding - kernel_size) // stride + 1, ) if shape[0] < 1 or shape[1] < 1: raise ValueError( """Your CNN backbone is too large for the input shape! Increase resolution or consider reducing the number of layers/removing stride/adding padding.""" ) elif block_type != "Inception_block": module = _blocks[block_type](in_channels=in_channels, **params) in_channels = params.get("out_channels", in_channels) if shape is not None: kernel_size = params.get("kernel_size", 3) padding = params.get("padding", 1) stride = params.get("stride", 1) shape = ( (shape[0] - kernel_size + 2 * padding) // stride + 1, (shape[1] - kernel_size + 2 * padding) // stride + 1, ) if shape[0] < 1 or shape[1] < 1: raise ValueError( """Your CNN backbone is too large for the input shape! Increase resolution or consider reducing the number of layers/removing stride/adding padding.""" ) else: module = _blocks[block_type](in_channels=in_channels, **params) in_channels = sum( [params.get(key) for key in ["ch1x1", "ch3x3", "ch5x5", "pool_proj"]] ) modules.append(module) cnn = nn.Sequential(*modules) if shape is not None: return cnn, shape return cnn
[docs] class SimpleRNN(nn.Module): """A simple recurrent neural network (RNN) module that supports RNN, LSTM, and GRU architectures. Args: rnn_type: (str) Type of RNN ('RNN', 'LSTM', or 'GRU'). input_size: (int) Number of input features. hidden_size: (int) Number of hidden units. num_layers: (int, optional) Number of RNN layers. Default is 1. bidirectional: (bool, optional) Whether the RNN is bidirectional. Default is False. pre_rnn_linear_size: (int, optional) Size of the optional linear layer before the RNN. Returns: Output tensor after the RNN and fully connected layer. """ def __init__( self, rnn_type, input_size, hidden_size, num_layers=1, bidirectional=False, pre_rnn_linear_size=None, ): super().__init__() self.hidden_size = hidden_size self.num_layers = num_layers self.bidirectional = bidirectional # Optional linear layer before the RNN if pre_rnn_linear_size is not None: self.pre_rnn_linear = nn.Linear(input_size, pre_rnn_linear_size) self.input_size = pre_rnn_linear_size # for the RNN layer else: self.pre_rnn_linear = None self.input_size = input_size # Original input size # Set the RNN layer based on the rnn_type parameter if rnn_type == "RNN": self.rnn = nn.RNN( self.input_size, hidden_size, num_layers, batch_first=True, bidirectional=bidirectional, ) elif rnn_type == "LSTM": self.rnn = nn.LSTM( self.input_size, hidden_size, num_layers, batch_first=True, bidirectional=bidirectional, ) elif rnn_type == "GRU": self.rnn = nn.GRU( self.input_size, hidden_size, num_layers, batch_first=True, bidirectional=bidirectional, ) else: raise ValueError( f"Unsupported rnn_type: {rnn_type}. Choose from 'RNN', 'LSTM', 'GRU'." ) def forward(self, sequences, lengths, return_all=False): # Apply the optional linear layer before the RNN x = sequences if self.pre_rnn_linear is not None: x = self.pre_rnn_linear(x) # Pack the padded sequence packed_x = nn.utils.rnn.pack_padded_sequence( x, lengths, batch_first=True, enforce_sorted=False ) # Forward pass through the RNN packed_out, hidden = self.rnn(packed_x) # Unpack the sequence out, _ = nn.utils.rnn.pad_packed_sequence(packed_out, batch_first=True) if return_all: return out, hidden # If the RNN is bidirectional, concatenate the hidden states # from both directions if self.bidirectional: if isinstance( hidden, tuple ): # For LSTM, hidden is a tuple (hidden_state, cell_state) hidden = torch.cat((hidden[0][-2], hidden[0][-1]), dim=1) else: # For RNN and GRU hidden = torch.cat((hidden[-2], hidden[-1]), dim=1) else: if isinstance(hidden, tuple): # For LSTM hidden = hidden[0][-1] else: # For RNN and GRU hidden = hidden[-1] # Pass the last hidden state through the fully connected layer return hidden
[docs] class VitNet(nn.Module): """Parent class for a vision-and-text network that fuses CNN and RNN-based representations using concatenation or addition. Args: CNN: (nn.Module) CNN backbone for processing image data. RNN: (nn.Module) RNN backbone for processing sequence data. fusion_mode: (str, optional) Fusion mode ('concat' or 'add'). Default is 'concat'. activation: (nn.Module, optional) Activation function applied after fusion. Default is None. embed_dim: (int, optional) Embedding dimension for the projected features. Default is 128. """ def __init__(self, CNN, RNN, fusion_mode="concat", activation=None, embed_dim=32): super().__init__() self.CNN = CNN self.RNN = RNN self.fusion_mode = fusion_mode self.activation = activation self.image_proj = nn.LazyLinear(embed_dim) self.seq_proj = nn.LazyLinear(embed_dim) self.rnn_proj = nn.LazyLinear(embed_dim) self.flat = nn.Flatten() def forward(self, images, sequences, lengths): x_1 = self.image_proj(self.flat(self.CNN(images))) projected_sequences = self.seq_proj(sequences) # SimpleRNN.forward takes (sequences, lengths) x_2_raw = self.RNN(projected_sequences, lengths) x_2 = self.rnn_proj(x_2_raw) if self.fusion_mode == "add": x = x_1 + x_2 elif self.fusion_mode == "concat": x = torch.cat([x_1, x_2], dim=1) else: raise NotImplementedError( f"Fusion mode {self.fusion_mode} is not implemented." ) if self.activation is not None: x = self.activation(x) return x
[docs] class VitNetWithCrossAttention(VitNet): """Child class extending VitNet by adding cross-attention between image and sequence representations. Args: CNN: (nn.Module) CNN backbone for processing image data. RNN: (nn.Module) RNN backbone for processing sequence data. input_dim: (int) Input dimension for the RNN. projected_dim: (int) Dimension of the projected sequence representation. cross_attention_fusion_mode: (str, optional) Fusion mode after cross-attention ('concat' or 'add'). Default is 'concat'. activation: (nn.Module, optional) Activation function applied after fusion. Default is None. embed_dim: (int, optional) Embedding dimension for the projected features. Default is 128. return_attention_weights: (bool, optional) Whether to return attention weights. Default is False. """ def __init__( self, CNN, RNN, cross_attention_fusion_mode="concat", activation=None, embed_dim=128, return_attention_weights=False, ): super().__init__( CNN, RNN, fusion_mode="concat", activation=activation, embed_dim=embed_dim, ) self.cross_attention_layer = nn.MultiheadAttention(embed_dim, num_heads=8) self.cross_attention_fusion_mode = cross_attention_fusion_mode self.return_attention_weights = return_attention_weights self.padding_idx = RNN.padding_idx def forward(self, images, sequences, lengths): # Get base fusion output x_base = super().forward(images, sequences, lengths) # x_img is projected image features from parent x_img = self.image_proj(self.flat(self.CNN(images))) projected_sequences = self.seq_proj(sequences) # We need all hidden states for cross attention all_hidden, _ = self.RNN(projected_sequences, lengths, return_all=True) query = x_img.unsqueeze(0) # (1, batch_size, embed_dim) # Project all hidden states to embed_dim for attention key = self.rnn_proj(all_hidden).transpose(0, 1) value = key key_padding_mask = sequences[:, :, 0] == self.padding_idx context_vector, attention_weights = self.cross_attention_layer( query, key, value, key_padding_mask=key_padding_mask ) output = context_vector.squeeze(0) if self.cross_attention_fusion_mode == "concat": x = torch.cat([x_base, output], dim=1) else: x = output if self.activation is not None: x = self.activation(x) if self.return_attention_weights: return x, attention_weights return x
[docs] class GCN(torch.nn.Module): """A graph convolutional network (GCN) with optional learnable node embeddings. Args: num_nodes: (int) Number of nodes in the graph. feature_dim: (int) Dimensionality of node features. embedding_dim: (int) Dimensionality of the learnable node embeddings. layer_sizes: (List[int]) List of hidden layer sizes for each GCN layer. out_channels: (int) Number of output channels. use_embeddings: (bool, optional) Whether to use learnable embeddings. Default is True. Returns: Output tensor after graph convolutions and global mean pooling. """ def __init__( self, num_nodes, feature_dim, embedding_dim, layer_sizes, out_channels, use_embeddings=True, ): super().__init__() self.use_embeddings = use_embeddings if use_embeddings: self.embeddings = torch.nn.Embedding( num_nodes, embedding_dim ) # Learnable embeddings in_channels = feature_dim + (embedding_dim if use_embeddings else 0) self.convs = torch.nn.ModuleList() # List of GCN layers for hidden_channels in layer_sizes: self.convs.append(GCNConv(in_channels, hidden_channels)) in_channels = hidden_channels # Update in_channels for next layer def forward(self, graphs): if self.use_embeddings: # Combine node features and learnable embeddings node_embeddings = self.embeddings.weight[ graphs.mapping ] # Retrieve embeddings for each node x = torch.cat([graphs.x, node_embeddings], dim=1) else: # Use only node features x = graphs.x # Apply each GCN layer with optional edge weights for conv in self.convs: x = conv(x, graphs.edge_index, edge_weight=graphs.edge_attr) x = F.relu(x) # Global pooling (mean pooling) x = global_mean_pool(x, graphs.batch) return x
[docs] class GIN(torch.nn.Module): """A graph isomorphism network (GIN) with shared embeddings for nodes across graphs. Args: num_common_nodes: (int) Number of shared nodes across graphs. feature_dim: (int) Dimensionality of node features. embedding_dim: (int) Dimensionality of the shared node embeddings. layer_sizes: (List[int]) List of hidden layer sizes for each GIN layer. out_channels: (int) Number of output channels. Returns: Output tensor after GIN layers and global mean pooling. """ def __init__( self, num_common_nodes, feature_dim, embedding_dim, layer_sizes, out_channels ): super().__init__() # Shared embeddings across graphs self.embeddings = torch.nn.Embedding(num_common_nodes, embedding_dim) # Determine the input dimension for the first GIN layer in_channels = feature_dim + embedding_dim self.convs = torch.nn.ModuleList() # List of GIN layers for hidden_channels in layer_sizes: mlp = torch.nn.Sequential( torch.nn.Linear(in_channels, hidden_channels), torch.nn.ReLU(), torch.nn.Linear(hidden_channels, hidden_channels), ) self.convs.append(GINConv(mlp)) in_channels = hidden_channels # Update in_channels for next layer def forward(self, graphs): # Retrieve the shared embeddings using the common index shared_embeddings = self.embeddings(graphs.common_index) # Concatenate node features with the shared embeddings x = torch.cat([graphs.x, shared_embeddings], dim=1) # Apply each GIN layer for conv in self.convs: x = conv(x, graphs.edge_index) x = F.relu(x) # Global pooling (mean pooling) x = global_mean_pool(x, graphs.batch) return x
class BaseModel(pl.LightningModule): """A base PyTorch Lightning module for neural networks with optional custom optimizer and scheduler. Args: backbone: (Union[nn.ModuleList, nn.Module]) The feature extraction backbone model. output_size: (int) Size of the final output layer. hidden_layers: (Tuple, optional) Tuple of hidden layer sizes. Default is empty. activation: (nn.Module, optional) Activation function to apply between layers. Default is ReLU. learning_rate: (float, optional) Learning rate for the optimizer. Default is 1e-3. optimizer_class: (Callable, optional) Optimizer class to use. Default is AdamW. optimizer_params: (dict, optional) Additional parameters for the optimizer. Default is None. scheduler_class: (Callable, optional) Scheduler class to use. Default is None. scheduler_params: (dict, optional) Additional parameters for the scheduler. Default is None. loss_fn: (Callable, optional) Loss function to use. Default is None. Returns: Output after the forward pass through the network. """ def __init__( self, backbone: nn.ModuleList | nn.Module, output_size, hidden_layers: tuple = (), activation=nn.ReLU(), learning_rate: float = 1e-3, optimizer_class: Callable = torch.optim.AdamW, optimizer_params: dict | None = None, scheduler_class: Callable | None = None, scheduler_params: dict | None = None, loss_fn: Callable | None = None, ): super().__init__() self.backbone = backbone modules = [] for hidden_units in hidden_layers: modules.append(nn.LazyLinear(hidden_units)) modules.append(activation) modules.append(nn.LazyLinear(output_size)) self.head = nn.ModuleList(modules) self.loss_fn = loss_fn self.LR = learning_rate # Optimizer and Scheduler configuration self.optimizer_class = optimizer_class self.optimizer_params = optimizer_params if optimizer_params is not None else {} self.scheduler_class = scheduler_class self.scheduler_params = scheduler_params if scheduler_params is not None else {} self.flat = nn.Flatten() def forward(self, x): if isinstance(x, dict): if isinstance(self.backbone, nn.Sequential): if len(x) == 1: x = self.backbone(next(iter(x.values()))) else: x = self.backbone(*x.values()) else: x = self.backbone(**x) else: x = self.backbone(x) if len(x.size()) == 4: x = self.flat(x) for layer in self.head: x = layer(x) return x def configure_optimizers(self): optimizer = self.optimizer_class( self.parameters(), lr=self.LR, **self.optimizer_params ) if self.scheduler_class is not None: scheduler = self.scheduler_class(optimizer, **self.scheduler_params) return [optimizer], [scheduler] else: return optimizer def training_step(self, batch, batch_idx): y = batch.pop("y") out = self(batch) loss = self.loss_fn(out, y) if getattr(self, "_trainer", None) is not None: self.log("train_loss", loss, on_step=False, on_epoch=True, prog_bar=True) return loss def validation_step(self, batch, batch_idx): y = batch.pop("y") out = self(batch) loss = self.loss_fn(out, y) if getattr(self, "_trainer", None) is not None: self.log("valid_loss", loss, on_step=False, on_epoch=True, prog_bar=True)
[docs] class Classifier(BaseModel): """A classification model built on top of the BaseModel, with additional accuracy, precision, recall, and F1-score tracking. Args: backbone: (Union[nn.ModuleList, nn.Module]) The feature extraction backbone model. n_classes: (int) Number of output classes. classifier_hidden_layers: (Tuple, optional) Tuple of hidden layer sizes for the classifier. Default is empty. classifier_activation: (nn.Module, optional) Activation function to use in the classifier. Default is ReLU. learning_rate: (float, optional) Learning rate for the optimizer. Default is 1e-3. optimizer_class: (Callable, optional) Optimizer class to use. Default is AdamW. optimizer_params: (dict, optional) Additional parameters for the optimizer. Default is None. scheduler_class: (Callable, optional) Scheduler class to use. Default is None. scheduler_params: (dict, optional) Additional parameters for the scheduler. Default is None. Returns: Output logits after the forward pass through the classifier. """ def __init__( self, backbone: nn.ModuleList | nn.Module, n_classes, classifier_hidden_layers=(), classifier_activation=nn.ReLU(), learning_rate=1e-3, optimizer_class: Callable = torch.optim.AdamW, optimizer_params: dict | None = None, scheduler_class: Callable | None = None, scheduler_params: dict | None = None, ): super().__init__( backbone=backbone, output_size=n_classes, hidden_layers=classifier_hidden_layers, activation=classifier_activation, learning_rate=learning_rate, optimizer_class=optimizer_class, optimizer_params=optimizer_params, scheduler_class=scheduler_class, scheduler_params=scheduler_params, loss_fn=nn.CrossEntropyLoss(), ) self.accuracy = torchmetrics.Accuracy( task="binary" if n_classes == 2 else "multiclass", num_classes=n_classes ) self.precision = torchmetrics.Precision( task="binary" if n_classes == 2 else "multiclass", num_classes=n_classes, average=None, # Per-class precision ) self.recall = torchmetrics.Recall( task="binary" if n_classes == 2 else "multiclass", num_classes=n_classes, average=None, # Per-class recall ) self.f1 = torchmetrics.F1Score( task="binary" if n_classes == 2 else "multiclass", num_classes=n_classes, average=None, # Per-class F1-score ) # For macro average metrics self.macro_precision = torchmetrics.Precision( task="binary" if n_classes == 2 else "multiclass", num_classes=n_classes, average="macro", # Macro-average precision ) self.macro_recall = torchmetrics.Recall( task="binary" if n_classes == 2 else "multiclass", num_classes=n_classes, average="macro", # Macro-average recall ) self.macro_f1 = torchmetrics.F1Score( task="binary" if n_classes == 2 else "multiclass", num_classes=n_classes, average="macro", # Macro-average F1-score ) self.prob = nn.Softmax(dim=1) def validation_step(self, batch, batch_idx): y = batch.pop("y") # Assuming that "y" is the ground truth labels out = self(batch) # Forward pass loss = self.loss_fn(out, y) out = self.prob(out) logits = torch.argmax(out, dim=1) # Calculate metrics accu = self.accuracy(logits, y) precision = self.precision(logits, y) recall = self.recall(logits, y) f1 = self.f1(logits, y) macro_precision = self.macro_precision(logits, y) macro_recall = self.macro_recall(logits, y) macro_f1 = self.macro_f1(logits, y) # Log loss and accuracy if getattr(self, "_trainer", None) is not None: self.log("valid_loss", loss, on_step=False, on_epoch=True, prog_bar=True) self.log("val_acc_step", accu, on_step=False, on_epoch=True, prog_bar=False) # Log precision, recall, and F1-score for each class if getattr(self, "_trainer", None) is not None: # Handle both binary (scalar) and multiclass (array) metrics if precision.dim() == 0: # Binary classification - metrics are scalars self.log( "val_precision_class_1", precision, on_step=False, on_epoch=True, prog_bar=False, ) self.log( "val_recall_class_1", recall, on_step=False, on_epoch=True, prog_bar=False, ) self.log( "val_f1_class_1", f1, on_step=False, on_epoch=True, prog_bar=False, ) else: # Multiclass classification - metrics are arrays for i in range(len(precision)): self.log( f"val_precision_class_{i}", precision[i], on_step=False, on_epoch=True, prog_bar=False, ) self.log( f"val_recall_class_{i}", recall[i], on_step=False, on_epoch=True, prog_bar=False, ) self.log( f"val_f1_class_{i}", f1[i], on_step=False, on_epoch=True, prog_bar=False, ) # Log macro averages if getattr(self, "_trainer", None) is not None: self.log( "val_macro_precision", macro_precision, on_step=False, on_epoch=True, prog_bar=False, ) self.log( "val_macro_recall", macro_recall, on_step=False, on_epoch=True, prog_bar=False, ) self.log( "val_macro_f1", macro_f1, on_step=False, on_epoch=True, prog_bar=False )
[docs] class Regressor(BaseModel): """A regression model built on top of the BaseModel, using mean squared error loss. Args: backbone: (Union[nn.ModuleList, nn.Module]) The feature extraction backbone model. output_dim: (int) Dimensionality of the regression output. regressor_hidden_layers: (Tuple, optional) Tuple of hidden layer sizes for the regressor. Default is empty. regressor_activation: (nn.Module, optional) Activation function to use in the regressor. Default is ReLU. learning_rate: (float, optional) Learning rate for the optimizer. Default is 1e-3. optimizer_class: (Callable, optional) Optimizer class to use. Default is AdamW. optimizer_params: (dict, optional) Additional parameters for the optimizer. Default is None. scheduler_class: (Callable, optional) Scheduler class to use. Default is None. scheduler_params: (dict, optional) Additional parameters for the scheduler. Default is None. Returns: Regression output after the forward pass. """ def __init__( self, backbone: nn.ModuleList | nn.Module, output_dim, regressor_hidden_layers=(), regressor_activation=nn.ReLU(), learning_rate=1e-3, optimizer_class: Callable = torch.optim.AdamW, optimizer_params: dict | None = None, scheduler_class: Callable | None = None, scheduler_params: dict | None = None, ): super().__init__( backbone=backbone, output_size=output_dim, hidden_layers=regressor_hidden_layers, activation=regressor_activation, learning_rate=learning_rate, optimizer_class=optimizer_class, optimizer_params=optimizer_params, scheduler_class=scheduler_class, scheduler_params=scheduler_params, loss_fn=nn.MSELoss(), )