from collections.abc import Callable
from copy import copy
from functools import partial
import numpy as np
import pandas as pd
import torch
from numpy.typing import ArrayLike
from torch.utils.data import Dataset
from torch_geometric.data import Data
from tqdm import tqdm
from eyefeatures.features.feature_maps import get_gafs, get_heatmaps, get_mtfs
from eyefeatures.preprocessing.base import BaseFixationPreprocessor
from eyefeatures.utils import _split_dataframe
from eyefeatures.visualization.static_visualization import get_visualizations
def _coord_to_grid(coords: np.array, xlim: tuple, ylim: tuple, shape: tuple):
"""Maps 2D coordinates to grid indices based on the grid resolution.
Args:
coords (np.ndarray): Array of coordinates to map.
xlim: The x-axis limits (x_min, x_max).
ylim: The y-axis limits (y_min, y_max).
shape: The shape of the grid (rows, cols).
Returns:
tuple(int, int): A tuple (i, j) - the grid indices
corresponding to the coordinates.
"""
i = (((coords[:, 0] - xlim[0]) / (xlim[1] - xlim[0])) * shape[0]).astype(int)
j = (((coords[:, 1] - ylim[0]) / (ylim[1] - ylim[0])) * shape[1]).astype(int)
return i, j
def _cell_index(i: int, j: int, shape: tuple[int, int]):
"""Maps grid indices (i, j) to a 1D cell index based on the grid shape.
Args:
i (int): Row index in the grid.
j (int): Column index in the grid.
shape (tuple(int,int)): The shape of the grid (rows, cols).
Returns:
int: The 1D cell index.
"""
return i * shape[1] + j
def _calculate_cell_center(i: int, j: int, xlim: tuple, ylim: tuple, shape: tuple):
"""Calculates the center coordinates of a grid cell.
Args:
i (int): Row index in the grid.
j (int): Column index in the grid.
xlim: The x-axis limits (x_min, x_max).
ylim: The y-axis limits (y_min, y_max).
shape: The shape of the grid (rows, cols).
Returns:
tuple(float,float):
-------
x_center, y_center: Tuple[float, float]
Center coordinates of the grid cell.
"""
cell_width = (xlim[1] - xlim[0]) / shape[0]
cell_height = (ylim[1] - ylim[0]) / shape[1]
x_center = xlim[0] + (i + 0.5) * cell_width
y_center = ylim[0] + (j + 0.5) * cell_height
return x_center, y_center
def _calculate_length_vectorized(coords: np.array):
"""
Calculates the Euclidean distance between consecutive points in 2D space.
:param coords: Array of coordinates with shape (n, 2).
Returns
-------
lengths: np.array
Euclidean distances between consecutive points.
"""
# Calculate the difference between consecutive points
dx = coords[1:, 0] - coords[:-1, 0]
dy = coords[1:, 1] - coords[:-1, 1]
# Calculate the Euclidean distance between consecutive points
lengths = np.sqrt(dx**2 + dy**2)
return lengths
def create_edge_list_and_cumulative_features(
    df, add_duration, x_col, y_col, xlim, ylim, shape, directed=True
):
    """Build the cell-transition edge list and per-node cumulative features.

    Every row of ``df`` is assigned to a grid cell; the visited cells become
    the graph nodes (re-indexed compactly) and every move between two
    different cells becomes an edge. Node and edge features are normalized
    by their respective maximum values (cell centers are not normalized).

    Args:
        df: DataFrame containing the coordinates and other node features.
        add_duration: Falsy to skip duration accumulation. If it is the name
            of a column of ``df`` that column is used; any other truthy value
            falls back to the ``"duration"`` column.
        x_col: Column name in df for the x coordinates.
        y_col: Column name in df for the y coordinates.
        xlim: Tuple (x_min, x_max) defining the bounds for the x-axis.
        ylim: Tuple (y_min, y_max) defining the bounds for the y-axis.
        shape: Tuple (x_res, y_res) defining the resolution of the grid.
        directed: If True, the graph is directional; if False, the reverse
            of every edge is added as well.

    Returns:
        edge_list: List of edges as pairs of compact node indices, one entry
            per traversal (so a repeated transition appears repeatedly).
        edge_features: For each entry of ``edge_list``, the summed length of
            that edge divided by the largest summed length.
        node_mapping: Mapping of grid cell indices to compact node indices.
        cumulative_node_features: Dictionary with:
            * ``total_duration``: Total duration at each node, normalized.
            * ``total_saccade_length_to``: Total saccade length directed to
              each node, normalized.
            * ``total_saccade_length_from``: Total saccade length originating
              from each node, normalized.
            * ``cell_centers``: Coordinates of the center of each cell.
    """
    coords = df[[x_col, y_col]].values
    i, j = _coord_to_grid(coords, xlim, ylim, shape)
    grid_indices = _cell_index(i, j, shape)

    # Compact the visited cells: `node_ids[k]` is the node of row k and
    # `first_rows[n]` is the first row that falls into node n.
    unique_nodes, first_rows, node_ids = np.unique(
        grid_indices, return_index=True, return_inverse=True
    )
    node_ids = node_ids.reshape(-1)
    node_mapping = {node: idx for idx, node in enumerate(unique_nodes)}
    num_nodes = len(unique_nodes)

    total_durations = np.zeros(num_nodes)
    total_saccade_length_to = np.zeros(num_nodes)
    total_saccade_length_from = np.zeros(num_nodes)

    if add_duration:
        # Honour a caller-supplied column name, keeping "duration" as the
        # fallback for callers that pass a plain flag.
        if isinstance(add_duration, str) and add_duration in df.columns:
            duration_col = add_duration
        else:
            duration_col = "duration"
        # Every row contributes to its own node, including rows followed by a
        # move to another cell and the final row.
        np.add.at(total_durations, node_ids, df[duration_col].to_numpy(dtype=float))

    # Compute the center of every node up front so nodes that never act as
    # the source of a transition still get their coordinates.
    x_centers, y_centers = _calculate_cell_center(
        i[first_rows], j[first_rows], xlim, ylim, shape
    )
    cell_centers = np.column_stack([x_centers, y_centers]).astype(float)

    edge_list = []
    edge_length_sum = {}
    lengths = _calculate_length_vectorized(coords)
    node_sequence = node_ids.tolist()
    for src_node, dst_node, length in zip(node_sequence, node_sequence[1:], lengths):
        # Consecutive rows in the same cell do not form an edge.
        if src_node == dst_node:
            continue

        edge = (src_node, dst_node)
        edge_list.append(edge)
        if not directed:
            reverse_edge = (dst_node, src_node)
            edge_list.append(reverse_edge)
            edge_length_sum[reverse_edge] = (
                edge_length_sum.get(reverse_edge, 0) + length
            )
        edge_length_sum[edge] = edge_length_sum.get(edge, 0) + length

        total_saccade_length_to[dst_node] += length
        total_saccade_length_from[src_node] += length

    # Normalize in place; the size check keeps an empty frame from raising.
    for feature in (
        total_durations,
        total_saccade_length_to,
        total_saccade_length_from,
    ):
        if feature.size and np.max(feature) > 0:
            feature /= np.max(feature)

    if edge_length_sum:
        max_edge_length_sum = np.max(list(edge_length_sum.values()))
        edge_features = [
            edge_length_sum[edge] / max_edge_length_sum for edge in edge_list
        ]
    else:
        edge_features = []

    cumulative_node_features = {
        "total_duration": total_durations,
        "total_saccade_length_to": total_saccade_length_to,
        "total_saccade_length_from": total_saccade_length_from,
        "cell_centers": cell_centers,
    }
    return edge_list, edge_features, node_mapping, cumulative_node_features
def create_graph_data_from_dataframe(
    df, y, x_col, y_col, add_duration, xlim, ylim, shape, directed=True
):
    """Convert a DataFrame into a PyTorch Geometric ``Data`` object.

    Node features are, in column order: cell center x, cell center y, total
    duration, total saccade length to the node and total saccade length from
    the node. Edge features are based on the sum of lengths of corresponding
    edges.

    Args:
        df: DataFrame containing the coordinates and other node features.
        y: Label stored on the graph (converted with ``torch.tensor``).
        x_col: Column name in df for the x coordinates.
        y_col: Column name in df for the y coordinates.
        add_duration: Column name in df for the duration between consecutive
            points (optional).
        xlim: Tuple (x_min, x_max) defining the bounds for the x-axis.
        ylim: Tuple (y_min, y_max) defining the bounds for the y-axis.
        shape: Tuple (x_res, y_res) defining the resolution of the grid.
        directed: If True, the graph is directional; if False, bidirectional
            edges are created.

    Returns:
        A PyTorch Geometric Data object with ``x``, ``y``, ``edge_index``,
        ``edge_attr`` and ``mapping`` (grid cell index of every node).
    """
    (
        edge_list,
        edge_features,
        node_mapping,
        cumulative_node_features,
    ) = create_edge_list_and_cumulative_features(
        df, add_duration, x_col, y_col, xlim, ylim, shape, directed
    )

    # Stack the per-node features column-wise into a single matrix.
    node_features = np.hstack(
        [
            cumulative_node_features["cell_centers"],
            cumulative_node_features["total_duration"].reshape(-1, 1),
            cumulative_node_features["total_saccade_length_to"].reshape(-1, 1),
            cumulative_node_features["total_saccade_length_from"].reshape(-1, 1),
        ]
    )
    x = torch.tensor(node_features, dtype=torch.float)

    if edge_list:
        edge_index = torch.tensor(edge_list, dtype=torch.long).t().contiguous()
    else:
        # A scanpath confined to one cell has no edges; keep the [2, 0]
        # layout instead of the 1-D tensor an empty list would produce.
        edge_index = torch.empty((2, 0), dtype=torch.long)

    edge_attr = torch.tensor(edge_features, dtype=torch.float).view(
        -1, 1
    )  # Shape [num_edges, 1]

    # Grid cell indices ordered by their compact node index.
    mapping = torch.tensor(sorted(node_mapping, key=node_mapping.get))

    data = Data(
        x=x,
        y=torch.tensor(y),
        edge_index=edge_index,
        edge_attr=edge_attr,
        mapping=mapping,
    )
    return data
# Registry of representation builders, keyed by name.
# Zoom variants:
# - *_fixed: uses fixed [0,1] coordinate space (shows absolute position)
# - *_zoomed: zooms to data range (fills the image with the scanpath region)
# The bare "heatmap" and "*_visualization" names keep each builder's own
# default zoom behaviour for backward compatibility.
_representations = {
    # Heatmaps
    "heatmap": get_heatmaps,
    "heatmap_fixed": partial(get_heatmaps, zoom_to_data=False),
    "heatmap_zoomed": partial(get_heatmaps, zoom_to_data=True),
    # Baseline, AOI and saccade visualizations: three variants per pattern,
    # generated in the order <name>_visualization, <name>_fixed, <name>_zoomed.
    **{
        f"{prefix}_{variant}": partial(get_visualizations, pattern=pattern, **zoom)
        for prefix, pattern in (
            ("baseline", "baseline"),
            ("aoi", "aoi"),
            ("saccade", "saccades"),
        )
        for variant, zoom in (
            ("visualization", {}),
            ("fixed", {"zoom_to_data": False}),
            ("zoomed", {"zoom_to_data": True}),
        )
    },
    # GAF (Gramian Angular Field) and MTF (Markov Transition Field) 2D maps
    # for DL (no zoom variant)
    "gaf_fixed": get_gafs,
    "mtf_fixed": get_mtfs,
}
[docs]
class Dataset2D(Dataset):
"""Custom dataset for 2D image-based representations derived from gaze data.
Args:
X: Input data.
Y: Labels for the data.
pk: List of primary keys for grouping.
shape: Shape of the images.
representations: List of representation types.
upload_to_cuda: If True, upload the data to the GPU. Default: False.
transforms: Transformations to apply to the data.
"""
def __init__(
self,
X: pd.DataFrame,
Y: ArrayLike,
x: str,
y: str,
pk: list[str],
shape: tuple[int] | int,
representations: list[str],
upload_to_cuda: bool = False,
transforms=None,
):
self.pmk = pk
rep_tensors = []
for i in representations:
rep_data = _representations[i](X, x, y, pk=pk, shape=shape)
rep_tensor = torch.tensor(rep_data, dtype=torch.float32)
# Only add channel dimension if not already present
# Heatmaps return (n, h, w), visualizations return (n, c, h, w)
if rep_tensor.ndim == 3:
rep_tensor = rep_tensor.unsqueeze(1)
rep_tensors.append(rep_tensor)
self.X = torch.cat(rep_tensors, dim=1)
self.channels_dim = self.X.shape[1]
print(f"Number of channels = {self.channels_dim}.")
if not isinstance(Y, pd.Series):
Y = Y.set_index(pk).squeeze(axis=0)
self.y = Y.sort_index().values
if np.issubdtype(self.y.dtype, np.integer):
self.y = torch.tensor(self.y, dtype=torch.long)
else:
self.y = torch.tensor(self.y, dtype=torch.float)
if upload_to_cuda:
self.X = self.X.cuda()
self.y = self.y.cuda()
self.transforms = transforms
def __len__(self):
return self.X.shape[0]
def __getitem__(self, idx: int):
X = self.X[idx, :, :, :]
label = self.y[idx]
if self.transforms is not None:
X = self.transforms(X)
return {
"images": X,
"y": label,
}
def collate_fn(self, batch):
images = torch.stack([x["images"] for x in batch])
y = torch.tensor([x["y"] for x in batch])
return {"images": images, "y": y}
def _get_features(X, features, x, y, t, pk):
    """Collect one feature array per group of ``X``.

    Args:
        X: Input dataframe, split into groups by ``pk``.
        features: Feature names to return next to the coordinates, or None
            to return only the ``x`` and ``y`` columns.
        x: X coordinate column name.
        y: Y coordinate column name.
        t: Passed through to ``BaseFixationPreprocessor``.
        pk: Primary keys for grouping.

    Returns:
        list: For every group, the values of ``[x, y]`` followed by the
        requested features.
    """
    # Without extra features only the raw coordinates are needed.
    if features is None:
        return [
            group_X[[x, y]].values for _, group_X in tqdm(_split_dataframe(X, pk))
        ]

    preprocessor = BaseFixationPreprocessor(x, y, t, pk)
    # Features already present as columns of X are not recomputed.
    missing_features = [name for name in features if name not in X.columns]

    output = []
    for _, group_X in tqdm(_split_dataframe(X, pk)):
        enriched = preprocessor._compute_feats(group_X, missing_features)
        output.append(enriched[[x, y] + features].values)
    return output
[docs]
class DatasetTimeSeries(Dataset):
"""Custom dataset for time-series data.
Args:
X: Input time-series data.
Y: Labels for the data.
pk: Primary keys for grouping.
features: List of features to extract. If None, only x and y coordinates are used.
transforms: Transformations to apply to the data.
max_length: maximum length of scanpath.
"""
def __init__(
self,
X: pd.DataFrame,
Y: ArrayLike,
x: str,
y: str,
pk: list[str],
features: list[str] | None = None,
transforms: Callable = None,
max_length: int = 10,
):
self.pmk = pk
self.X = _get_features(X, features, x, y, t=None, pk=pk)
if not isinstance(Y, pd.Series):
Y = Y.set_index(pk).squeeze(axis=0)
self.Y = Y.sort_index().values
if np.issubdtype(self.Y.dtype, np.integer):
self.Y = torch.tensor(self.Y, dtype=torch.long)
else:
self.Y = torch.tensor(self.Y, dtype=torch.float)
self.n_features = 2 + (len(features) if features is not None else 0)
self.transforms = transforms
self.max_length = max_length
def __len__(self):
return len(self.X)
def __getitem__(self, idx: int):
X = self.X[idx]
label = self.Y[idx]
if self.transforms:
X = self.transforms(X)
return {
"sequences": torch.tensor(X, dtype=torch.float),
"y": label,
}
def collate_fn(self, batch):
if self.max_length is None:
lengths = [x["sequences"].shape[0] for x in batch]
max_len = max(lengths)
padded_batch = [
torch.cat(
[
x["sequences"],
torch.zeros(max_len - x["sequences"].shape[0], self.n_features),
],
axis=0,
)
for x in batch
]
else:
max_len = self.max_length
lengths = [min(x["sequences"].shape[0], max_len) for x in batch]
padded_batch = [
torch.cat(
[
x["sequences"][: self.max_length],
torch.zeros(
max_len - x["sequences"][: self.max_length].shape[0],
self.n_features,
),
],
axis=0,
)
for x in batch
]
y = torch.tensor([x["y"] for x in batch])
return {
"sequences": torch.stack(padded_batch),
"lengths": torch.tensor(lengths),
"y": y,
}
[docs]
class TimeSeries_2D_Dataset(Dataset):
    """Composite dataset that combines image and time-series data.

    Samples are read straight from the ``X`` / ``y`` attributes of the two
    wrapped datasets, so their own ``__getitem__`` (and any transforms applied
    there) is not used.

    Args:
        image_dataset: Dataset containing image data.
        sequence_dataset: Dataset containing sequence data.
    """

    def __init__(self, image_dataset: Dataset, sequence_dataset: Dataset):
        # Samples are paired by position, so both datasets must line up.
        n_images, n_sequences = len(image_dataset), len(sequence_dataset)
        assert n_images == n_sequences, "Datasets must be of the same length"

        self.image_dataset = image_dataset
        self.sequence_dataset = sequence_dataset

    def __len__(self):
        """Return the number of paired samples."""
        return len(self.image_dataset)

    def __getitem__(self, idx):
        """Return the image, sequence and label stored at ``idx``."""
        image = self.image_dataset.X[idx, :, :, :]
        sequence = self.sequence_dataset.X[idx]
        label = self.image_dataset.y[idx]

        # Use float32 so batch matches model parameters (avoid Double vs Float mismatch)
        return {
            "images": torch.as_tensor(image, dtype=torch.float32),
            "sequences": torch.as_tensor(sequence, dtype=torch.float32),
            "y": label,
        }

    def collate_fn(self, batch):
        """Batch samples, zero-padding sequences to the longest in the batch.

        Returns:
            dict with ``"sequences"``, ``"lengths"`` (original sequence
            lengths), ``"images"`` and ``"y"``.
        """
        sequences = [item["sequences"] for item in batch]
        lengths = [seq.shape[0] for seq in sequences]
        longest = max(lengths)
        width = self.sequence_dataset.n_features

        padded_batch = []
        for seq in sequences:
            filler = torch.zeros(longest - seq.shape[0], width, dtype=torch.float32)
            padded_batch.append(torch.cat([seq, filler], dim=0))

        y = torch.tensor([item["y"] for item in batch])
        return {
            "sequences": torch.stack(padded_batch),
            "lengths": torch.tensor(lengths),
            "images": torch.stack([item["images"] for item in batch]),
            "y": y,
        }
[docs]
class GridGraphDataset(Dataset):
    """Custom dataset for generating grid-based graph
    representations from spatial coordinates.

    All graphs are built once at construction time, one per group of ``X``.

    Args:
        X: Input dataframe.
        Y: Labels for the data. Anything that is not a ``pd.Series`` is
            indexed by ``pk`` first; labels are then sorted by index and
            matched to the groups of ``X`` by position.
        x: X coordinate column name.
        y: Y coordinate column name.
        add_duration: Column name for time durations.
        pk: Primary keys for grouping.
        xlim: Limits of the x-axis.
        ylim: Limits of the y-axis.
        shape: Shape of the grid.
        directed: Whether the graph is directed.
        transforms: Optional callable applied to a graph each time it is
            fetched with ``__getitem__``.
    """

    def __init__(
        self,
        X: pd.DataFrame,
        Y: ArrayLike,
        x: str,
        y: str,
        add_duration: str,
        pk: list[str],
        xlim: tuple[float, float] = (0, 1),
        ylim: tuple[float, float] = (0, 1),
        shape: tuple[int, int] = (10, 10),
        directed: bool = True,
        transforms: Callable = None,
    ):
        super().__init__()
        self.transform = transforms
        self.pk = pk
        self.directed = directed

        if not isinstance(Y, pd.Series):
            Y = Y.set_index(pk).squeeze(axis=0)
        # Sort Series labels as well, matching the DataFrame path and the
        # other datasets in this module.
        # NOTE(review): assumes _split_dataframe yields groups in sorted pk
        # order — confirm against its implementation.
        Y = Y.sort_index().values

        self.X = self.get_graphs(X, Y, x, y, add_duration, xlim, ylim, shape)

    def get_graphs(self, X, Y, x_col, y_col, add_duration, xlim, ylim, shape):
        """Build one graph per group of ``X``, labelled with ``Y[i]``.

        Returns:
            list: The graphs, in the order the groups are produced.
        """
        groups = _split_dataframe(X, pk=self.pk)
        graphs = []
        for i, (group_id, cur_X) in tqdm(enumerate(groups), desc="Getting graphs..."):
            graphs.append(
                create_graph_data_from_dataframe(
                    cur_X,
                    Y[i],
                    x_col,
                    y_col,
                    add_duration,
                    xlim,
                    ylim,
                    shape,
                    directed=self.directed,
                )
            )
        return graphs

    def __len__(self):
        """Return the number of graphs."""
        return len(self.X)

    def __getitem__(self, idx):
        """Return the graph at ``idx``, transformed if ``transforms`` was given."""
        graph = self.X[idx]
        # The transform was previously stored but never used.
        if self.transform is not None:
            graph = self.transform(graph)
        return graph

    def collate_fn(self, batch):
        """Return the batch unchanged (a list of graphs)."""
        return batch