"""
Simple data loading utilities for the eye-tracking collection.
Collection data lives in this repo at ``data/collection`` as Parquet files
(tracked with Git LFS). You can pass a custom path or use the default.
Column conventions:
- Primary key (pk): columns starting with ``group_``
- Labels: columns ending with ``_label``
- Meta: columns starting with ``meta_``
"""
import json
from pathlib import Path
from typing import Any
import pandas as pd
#: Default root directory for collection Parquet files (``data/collection`` in the repo, Git LFS).
#: The path is relative, so it is resolved against the current working directory.
DEFAULT_COLLECTION_DIR = Path("data/collection")
def _classify_dataset_type(dataset_name: str) -> str:
    """Return the dataset kind implied by the suffix of ``dataset_name``.

    Names ending in ``_gaze`` or ``_gazes`` are classified as ``"gaze"``.
    Every other name -- including those ending in ``_fixation`` or
    ``_fixations`` -- is classified as ``"fixation"``, which is the fallback.
    """
    gaze_suffixes = ("_gaze", "_gazes")
    return "gaze" if dataset_name.endswith(gaze_suffixes) else "fixation"
[docs]
def list_datasets(
collection_dir: str | Path | None = None,
*,
include_extensive_collection: bool = True,
extensive_collection_only: bool = False,
include_extracted_fixations: bool = True,
extracted_fixations_only: bool = False,
dataset_type: str | None = None,
) -> list[str]:
"""List available dataset names in the collection directory.
Parameters
----------
collection_dir : path, optional
Root directory containing collection Parquet files.
Defaults to ``data/collection`` (repo data tracked with Git LFS).
include_extensive_collection : bool, default True
If True, also search in extensive_collection subfolder.
Ignored when extensive_collection_only or extracted_fixations_only is True.
extensive_collection_only : bool, default False
If True, list only datasets from extensive_collection subfolder
(main directory is not scanned).
include_extracted_fixations : bool, default True
If True, also search in extracted_fixations subfolder.
Ignored when extensive_collection_only or extracted_fixations_only is True.
extracted_fixations_only : bool, default False
If True, list only datasets from extracted_fixations subfolder
(main directory is not scanned).
dataset_type : str, optional
If "gaze", return only gaze datasets (names ending with _gaze/_gazes).
If "fixation", return only fixation datasets (names ending with
_fixations/_fixation or default). If None, return all.
Returns
-------
list of str
Sorted list of dataset names (without .parquet extension).
"""
collection_path = (
Path(collection_dir) if collection_dir is not None else DEFAULT_COLLECTION_DIR
)
dataset_names = set()
if extracted_fixations_only:
extracted_dir = collection_path / "extracted_fixations"
if extracted_dir.exists():
for f in extracted_dir.glob("*.parquet"):
dataset_names.add(f.stem)
elif extensive_collection_only:
extensive_dir = collection_path / "extensive_collection"
if extensive_dir.exists():
for f in extensive_dir.glob("*.parquet"):
dataset_names.add(f.stem)
else:
for f in collection_path.glob("*.parquet"):
dataset_names.add(f.stem)
if include_extensive_collection:
extensive_dir = collection_path / "extensive_collection"
if extensive_dir.exists():
for f in extensive_dir.glob("*.parquet"):
dataset_names.add(f.stem)
if include_extracted_fixations:
extracted_dir = collection_path / "extracted_fixations"
if extracted_dir.exists():
for f in extracted_dir.glob("*.parquet"):
dataset_names.add(f.stem)
if dataset_type is not None:
dataset_names = {
name
for name in dataset_names
if _classify_dataset_type(name) == dataset_type
}
return sorted(dataset_names)
[docs]
def load_dataset(
    dataset_name: str,
    collection_dir: str | Path | None = None,
    *,
    normalize: bool = True,
) -> tuple[pd.DataFrame, dict]:
    """Load a collection dataset by name.

    Parameters
    ----------
    dataset_name : str
        Name of the dataset (e.g. "ASD_ready_data_fixations").
        ``{dataset_name}.parquet`` is looked up in ``collection_dir`` first,
        then in its ``extensive_collection`` subfolder, then in its
        ``extracted_fixations`` subfolder; the first match is used.
    collection_dir : path, optional
        Root directory containing collection Parquet files.
        Defaults to ``data/collection`` (repo data tracked with Git LFS).
    normalize : bool, default True
        If True and the dataset has both ``x`` and ``y`` columns, add
        ``norm_pos_x``/``norm_pos_y`` columns (each coordinate divided by its
        column maximum) for whichever of the two is not already present.
        The original ``x``/``y`` columns are kept. When a column maximum is
        not positive, the values are copied unscaled.

    Returns
    -------
    tuple (DataFrame, meta_info)
        - DataFrame: loaded and optionally normalized data
        - meta_info: dict with 'pk', 'labels', 'meta' column lists and 'info'
          (from collection_dir/meta.json under key dataset_name, if present).

    Raises
    ------
    FileNotFoundError
        If the Parquet file exists in none of the three searched locations.
    """
    collection_path = (
        Path(collection_dir) if collection_dir is not None else DEFAULT_COLLECTION_DIR
    )
    filename = f"{dataset_name}.parquet"

    # Order matters: the main directory wins over the subfolders.
    candidates = (
        collection_path / filename,
        collection_path / "extensive_collection" / filename,
        collection_path / "extracted_fixations" / filename,
    )
    dataset_path = next((path for path in candidates if path.exists()), None)
    if dataset_path is None:
        raise FileNotFoundError(
            f"Dataset '{dataset_name}' not found in {collection_path}, "
            f"{collection_path / 'extensive_collection'}, or "
            f"{collection_path / 'extracted_fixations'}"
        )

    df = pd.read_parquet(dataset_path)

    # Parquet preserves types; ensure numeric for x/y if present (e.g. older
    # exports that stored decimals as text with a comma separator).
    for axis in ("x", "y"):
        if axis in df.columns and not pd.api.types.is_numeric_dtype(df[axis]):
            df[axis] = pd.to_numeric(
                df[axis].astype(str).str.replace(",", "."), errors="coerce"
            )

    # Derive a combined coordinate from the per-eye columns. Each axis is
    # checked on its own so that a file with x_left/x_right but without
    # y_left/y_right no longer fails with KeyError.
    for axis in ("x", "y"):
        left, right = f"{axis}_left", f"{axis}_right"
        if axis not in df.columns and left in df.columns and right in df.columns:
            df[axis] = (df[left] + df[right]) / 2

    # Normalize if requested and needed. Each target column is handled
    # independently so an existing norm_pos_y is never overwritten and a
    # missing one is still created when norm_pos_x is already there.
    if normalize and "x" in df.columns and "y" in df.columns:
        for axis in ("x", "y"):
            target = f"norm_pos_{axis}"
            if target not in df.columns:
                peak = df[axis].max()
                # A non-positive (or NaN) maximum cannot be used as a scale.
                df[target] = df[axis] / peak if peak > 0 else df[axis]

    # Build meta info
    meta_info = {
        "pk": get_pk(df),
        "labels": get_labels(df),
        "meta": get_meta(df),
        "info": _load_meta_info(collection_path, dataset_name),
    }
    return df, meta_info
def _load_meta_info(collection_path: Path, dataset_name: str) -> Any | None:
"""Load meta.json from collection dir and return value for dataset_name key."""
meta_path = collection_path / "meta.json"
if not meta_path.exists():
return None
try:
with open(meta_path, encoding="utf-8") as f:
data = json.load(f)
return data.get(dataset_name)
except (json.JSONDecodeError, OSError):
return None
[docs]
def get_pk(df: pd.DataFrame) -> list[str]:
    r"""Get primary key column names (columns starting with ``group\_``).

    Parameters
    ----------
    df : DataFrame
        Dataset DataFrame.

    Returns
    -------
    list of str
        Primary key column names, in the order they appear in ``df``.
        Column labels that are not strings (e.g. integers) are skipped.
    """
    # Column labels are not guaranteed to be strings; guard before calling
    # str methods so such labels are ignored instead of raising.
    return [
        col
        for col in df.columns
        if isinstance(col, str) and col.startswith("group_")
    ]
[docs]
def get_labels(df: pd.DataFrame) -> list[str]:
    """Get label column names (columns ending with ``_label``).

    Parameters
    ----------
    df : DataFrame
        Dataset DataFrame.

    Returns
    -------
    list of str
        Label column names, in the order they appear in ``df``.
        Column labels that are not strings (e.g. integers) are skipped.
    """
    # Column labels are not guaranteed to be strings; guard before calling
    # str methods so such labels are ignored instead of raising.
    return [
        col
        for col in df.columns
        if isinstance(col, str) and col.endswith("_label")
    ]