VitNetWithCrossAttention

class eyefeatures.deep.models.VitNetWithCrossAttention(CNN, RNN, cross_attention_fusion_mode='concat', activation=None, embed_dim=128, return_attention_weights=False)

Bases: VitNet

Subclass of VitNet that adds cross-attention between the image and sequence representations.
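The exact attention layout is not documented here. As a rough illustration of what "cross-attention between image and sequence representations" usually means, the sketch below implements single-head scaled dot-product cross-attention in NumPy rather than PyTorch. It assumes that sequence (RNN) features act as queries and image (CNN) features supply keys and values, and that both are projected to embed_dim. The function name, projection matrices (w_q, w_k, w_v) and feature sizes are hypothetical and are not taken from eyefeatures.

```python
import numpy as np

def softmax(x, axis=-1):
    # Subtract the row max for numerical stability before exponentiating.
    z = x - x.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

def cross_attention(query_feats, context_feats, w_q, w_k, w_v):
    """Hypothetical single-head scaled dot-product cross-attention.

    query_feats:   (n_q, d_q) e.g. sequence (RNN) features (assumption)
    context_feats: (n_k, d_k) e.g. flattened image (CNN) features (assumption)
    w_q: (d_q, embed_dim); w_k, w_v: (d_k, embed_dim) learned projections
    """
    q = query_feats @ w_q                     # (n_q, embed_dim)
    k = context_feats @ w_k                   # (n_k, embed_dim)
    v = context_feats @ w_v                   # (n_k, embed_dim)
    scores = q @ k.T / np.sqrt(q.shape[-1])   # (n_q, n_k)
    weights = softmax(scores, axis=-1)        # each query row sums to 1
    return weights @ v, weights               # (n_q, embed_dim), (n_q, n_k)

rng = np.random.default_rng(0)
embed_dim = 128
seq_feats = rng.standard_normal((10, 64))    # 10 sequence steps, 64-dim RNN output
img_feats = rng.standard_normal((49, 256))   # 7x7 CNN feature map flattened, 256 channels
w_q = rng.standard_normal((64, embed_dim)) * 0.1
w_k = rng.standard_normal((256, embed_dim)) * 0.1
w_v = rng.standard_normal((256, embed_dim)) * 0.1

attended, weights = cross_attention(seq_feats, img_feats, w_q, w_k, w_v)
print(attended.shape, weights.shape)  # (10, 128) (10, 49)
```

Each sequence step ends up with an embed_dim vector that summarizes the image regions it attends to. The weights matrix is the kind of output a flag like return_attention_weights would typically expose.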

Parameters:
  • CNN – (nn.Module) CNN backbone for processing image data.

  • RNN – (nn.Module) RNN backbone for processing sequence data.

  • input_dim – (int) Input dimension for the RNN.

  • projected_dim – (int) Dimension of the projected sequence representation.

  • cross_attention_fusion_mode – (str, optional) Fusion mode applied after cross-attention ('concat' or 'add'). Default is 'concat'.

  • activation – (nn.Module, optional) Activation function applied after fusion. Default is None.

  • embed_dim – (int, optional) Embedding dimension for the projected features. Default is 128.

  • return_attention_weights – (bool, optional) Whether to return attention weights. Default is False.
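This page does not describe how cross_attention_fusion_mode and activation interact. The NumPy sketch below shows one plausible reading, in which 'concat' joins the two embed_dim representations and 'add' sums them element-wise, with activation=None meaning no nonlinearity. The fuse function, the pooled (batch, embed_dim) input shapes and the relu helper are hypothetical illustrations, not the library's implementation.

```python
import numpy as np

def fuse(img_repr, seq_repr, mode="concat", activation=None):
    """Hypothetical fusion step mirroring cross_attention_fusion_mode.

    img_repr, seq_repr: (batch, embed_dim) pooled representations (assumption).
    """
    if mode == "concat":
        fused = np.concatenate([img_repr, seq_repr], axis=-1)  # (batch, 2 * embed_dim)
    elif mode == "add":
        fused = img_repr + seq_repr                            # (batch, embed_dim)
    else:
        raise ValueError(f"unknown fusion mode: {mode!r}")
    # activation=None means no nonlinearity, matching the documented default.
    return activation(fused) if activation is not None else fused

def relu(x):
    # Element-wise ReLU standing in for an nn.Module activation.
    return np.maximum(x, 0.0)

embed_dim = 128
img = np.ones((4, embed_dim))
seq = -2.0 * np.ones((4, embed_dim))

print(fuse(img, seq, "concat").shape)     # (4, 256)
print(fuse(img, seq, "add").shape)        # (4, 128)
print(fuse(img, seq, "add", relu).max())  # 0.0
```

In this sketch 'concat' doubles the fused width to 2 * embed_dim, so any layer that consumes the fused features must be sized accordingly. 'add' keeps the width at embed_dim but requires both inputs to share that dimension.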