Source code for pyeeglab.dataset.dataset

import os
import json
import logging
import hashlib
import pickle

from abc import ABC, abstractmethod
from dataclasses import dataclass
from functools import reduce
from multiprocessing import Pool, cpu_count
from operator import add, and_
from uuid import uuid4, uuid5, NAMESPACE_X500

from typing import Dict, List, Tuple

import mne
from sqlalchemy import create_engine
from sqlalchemy.orm import Session, sessionmaker, Query

from .declarative_base import Base
from .file import File
from .metadata import Metadata
from .annotation import Annotation

from ..pipeline import Pipeline


@dataclass(init=False)
class Dataset(ABC):

    path: str
    name: str
    version: str

    extensions: List[str]
    exclude_file: List[str]
    exclude_channels_set: List[str]
    exclude_channels_reference: List[str]
    exclude_sampling_frequency: List[int]
    minimum_annotation_duration: float

    session: Session
    query: Query

    pipeline: Pipeline = None

    def __init__(
            self,
            path: str,
            name: str,
            version: str = None,
            extensions: List[str] = [".edf"],
            exclude_file: List[str] = None,
            exclude_channels_set: List[str] = None,
            exclude_channels_reference: List[str] = None,
            exclude_sampling_frequency: List[int] = None,
            minimum_annotation_duration: float = None
        ) -> None:
        # Set basic attributes
        self.path = os.path.abspath(os.path.join(path, version))
        self.name = name
        self.version = version
        # Set data set filter attributes
        self.extensions = extensions if extensions else []
        self.exclude_file = exclude_file if exclude_file else []
        self.exclude_channels_set = exclude_channels_set if exclude_channels_set else []
        self.exclude_channels_reference = exclude_channels_reference if exclude_channels_reference else []
        self.exclude_sampling_frequency = exclude_sampling_frequency if exclude_sampling_frequency else []
        self.minimum_annotation_duration = minimum_annotation_duration if minimum_annotation_duration else 0
        logging.info("Init dataset '%s'@'%s' at '%s'", self.name, self.version, self.path)
        # Make workspace directory
        logging.debug("Make .pyeeglab directory")
        workspace = os.path.join(self.path, ".pyeeglab")
        os.makedirs(workspace, exist_ok=True)
        logging.debug("Make .pyeeglab/cache directory")
        os.makedirs(os.path.join(workspace, "cache"), exist_ok=True)
        logging.debug("Set MNE log to .pyeeglab/mne.log")
        mne.set_log_file(os.path.join(workspace, "mne.log"), overwrite=False)
        # Index data set files
        self.index()

    def __getstate__(self):
        # Workaround for the unpicklable sqlalchemy.orm.session
        # during multiprocess dataset loading
        state = self.__dict__.copy()
        for attribute in ["session", "query"]:
            if hasattr(self, attribute):
                del state[attribute]
        return state
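    # NOTE: illustrative sketch, not part of the original module. Dropping the
    # SQLAlchemy session and query from the pickled state is what allows the
    # multiprocessing Pool in index() to ship bound Dataset methods to worker
    # processes; the round-trip below assumes a concrete subclass instance:
    #
    #     clone = pickle.loads(pickle.dumps(dataset))
    #     assert "session" not in clone.__dict__ and "query" not in clone.__dict__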
    @abstractmethod
    def download(self, user: str = None, password: str = None) -> None:
        pass
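    # NOTE: illustrative sketch, not part of the original module. A concrete
    # data set is obtained by subclassing Dataset and implementing download();
    # the class name, name/version values, and download strategy below are
    # hypothetical:
    #
    #     class MyCorpusDataset(Dataset):
    #         def __init__(self, path: str) -> None:
    #             super().__init__(path, name="mycorpus", version="v1.0.0")
    #
    #         def download(self, user: str = None, password: str = None) -> None:
    #             # Fetch the raw recordings into self.path (e.g. over HTTPS/SFTP).
    #             ...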
    def index(self) -> None:
        # Init index session
        logging.debug("Make index session")
        connection = os.path.join(self.path, ".pyeeglab", "index.sqlite3")
        connection = create_engine("sqlite:///" + connection)
        Base.metadata.create_all(connection)
        self.session = sessionmaker(bind=connection)()
        # Open multiprocess pool
        logging.info("Index data set directory")
        pool = Pool(cpu_count())
        # Get file paths from data set path
        paths = [
            os.path.join(directory, filename)
            for directory, _, filenames in os.walk(self.path)
            for filename in filenames
        ]
        # Get File instances from paths, filtering out already indexed files
        files = self.session.query(File).all()
        files = [file.uuid for file in files]
        files = [
            file
            for file in pool.map(self._get_file, paths)
            if file.uuid not in files
        ]
        for file in files:
            logging.debug("Add file %s to index", file.uuid)
        # Filter raw data files by extension
        raws = [
            file
            for file in files
            if os.path.splitext(file.path)[-1] in self.extensions
        ]
        # Get metadata and annotations for data files
        metadatas = pool.map(self._get_metadata, raws)
        annotations = pool.map(self._get_annotation, raws)
        # Close multiprocess pool
        pool.close()
        pool.join()
        # Commit insertions to index
        commits = files + metadatas + reduce(add, annotations, [])
        if commits:
            logging.info("Commit insertions to index")
            self.session.add_all(commits)
            self.session.commit()
        logging.info("Index data set completed")
        # Init default query
        logging.debug("Init default query")
        self.query = self.session.query(File, Metadata, Annotation).\
            join(File.meta).\
            join(File.annotations).\
            filter(~Metadata.channels_reference.in_(self.exclude_channels_reference)).\
            filter(~Metadata.sampling_frequency.in_(self.exclude_sampling_frequency)).\
            filter((Annotation.end - Annotation.begin) >= self.minimum_annotation_duration)
        # Filter excluded file paths
        for file in self.exclude_file:
            self.query = self.query.filter(~File.path.like("%{}%".format(file)))
        logging.debug("SQL query representation: '%s'", str(self.query).replace("\n", ""))
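    # NOTE: illustrative sketch, not part of the original module. After index()
    # has run, self.query joins File, Metadata and Annotation rows and can be
    # narrowed further with standard SQLAlchemy filters; the label value below
    # is hypothetical:
    #
    #     dataset.query = dataset.query.filter(Annotation.label == "seizure")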
    def _get_file(self, path: str) -> File:
        return File(
            uuid=str(uuid5(NAMESPACE_X500, path)),
            path=path,
            extension=os.path.splitext(path)[-1]
        )

    def _get_metadata(self, file: File) -> Metadata:
        logging.debug("Add file %s metadata to index", file.uuid)
        with file as reader:
            info = reader.info
            metadata = Metadata(
                file_uuid=file.uuid,
                duration=reader.n_times/info["sfreq"],
                channels_set=json.dumps(info["ch_names"]),
                sampling_frequency=info["sfreq"],
                max_value=reader.get_data().max(),
                min_value=reader.get_data().min(),
            )
        return metadata

    def _get_annotation(self, file: File) -> List[Annotation]:
        logging.debug("Add file %s annotations to index", file.uuid)
        with file as reader:
            annotations = [
                Annotation(
                    uuid=str(uuid4()),
                    file_uuid=file.uuid,
                    begin=annotation[0],
                    end=annotation[0]+annotation[1],
                    label=annotation[2],
                )
                for annotation in reader.annotations
            ]
        return annotations

    @property
    def environment(self) -> Dict:
        min_max = self.signal_min_max_range
        return {
            "channels_set": self.maximal_channels_subset,
            "lowest_frequency": self.lowest_frequency,
            "min_value": min_max[0],
            "max_value": min_max[1],
        }

    @property
    def lowest_frequency(self) -> float:
        frequency = self.query.all()
        frequency = min([
            f[1].sampling_frequency
            for f in frequency
        ], default=0)
        return frequency

    @property
    def maximal_channels_subset(self) -> List[str]:
        channels = self.query.group_by(Metadata.channels_set).all()
        channels = [
            frozenset(json.loads(channel[1].channels_set))
            for channel in channels
        ]
        channels = reduce(and_, channels)
        channels = channels - frozenset(self.exclude_channels_set)
        channels = sorted(channels)
        return channels

    @property
    def signal_min_max_range(self) -> Tuple[float, float]:
        min_max = self.query.all()
        min_max = [m[1] for m in min_max]
        min_max = tuple([
            min([m.min_value for m in min_max], default=0),
            max([m.max_value for m in min_max], default=0),
        ])
        return min_max
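    # NOTE: illustrative sketch, not part of the original module. The derived
    # properties above can be inspected directly on an indexed data set instance:
    #
    #     env = dataset.environment
    #     print(env["channels_set"])                 # common channel subset across files
    #     print(env["lowest_frequency"])             # lowest sampling frequency found
    #     print(env["min_value"], env["max_value"])  # global signal amplitude range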
    def set_pipeline(self, pipeline: Pipeline) -> "Dataset":
        self.pipeline = pipeline
        self.pipeline.environment.update(self.environment)
        return self
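    # NOTE: illustrative sketch, not part of the original module. set_pipeline()
    # returns self, so it can be chained with load(); the Pipeline construction
    # below is hypothetical and depends on the preprocessing steps available in
    # pyeeglab.pipeline:
    #
    #     dataset = MyCorpusDataset("/data/mycorpus").set_pipeline(Pipeline(...))
    #     data = dataset.load()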
    def set_minimum_event_duration(self, duration: float) -> "Dataset":
        logging.warning("This function will be deprecated in the near future")
        self.minimum_annotation_duration = duration
        return self
    def load(self) -> Dict:
        # Compute cache path
        cache = os.path.join(self.path, ".pyeeglab", "cache")
        # Compute cache key
        logging.info("Compute cache key")
        name = self.__class__.__name__.lower()
        if name.endswith("dataset"):
            name = name[:-len("dataset")]
        key = [hash(self), hash(self.pipeline)]
        key = [str(k).encode() for k in key]
        key = [hashlib.md5(k).hexdigest()[:10] for k in key]
        key = list(zip(["loader", "pipeline"], key))
        key = ["_".join(k) for k in key]
        key = name + "_" + "_".join(key)
        logging.info("Computed cache key: %s", key)
        # Load file cache
        cache = os.path.join(cache, key + ".pkl")
        if os.path.exists(cache):
            logging.info("Cache file found at %s", cache)
            with open(cache, "rb") as reader:
                try:
                    logging.info("Loading cache file")
                    return pickle.load(reader)
                except Exception:
                    logging.error("Loading cache file failed")
        # Cache file not found, start preprocessing
        logging.info("Cache file not found, generating new one")
        data = [row[2] for row in self.query.all()]
        data = self.pipeline.run(data)
        with open(cache, "wb") as file:
            logging.info("Dumping cache file")
            pickle.dump(data, file)
        return data
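    # NOTE: illustrative sketch, not part of the original module. Given the key
    # construction above, cache files follow the pattern
    # "<name>_loader_<10 hex chars>_pipeline_<10 hex chars>.pkl"; for a class
    # named "MyCorpusDataset" this would look like (hash digits hypothetical):
    #
    #     .pyeeglab/cache/mycorpus_loader_0123456789_pipeline_abcdef0123.pkl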
    def __hash__(self) -> int:
        key = [self.path, self.version, self.minimum_annotation_duration]
        key += self.exclude_file
        key += self.exclude_channels_set
        key += self.exclude_channels_reference
        key += self.exclude_sampling_frequency
        key = json.dumps(key).encode()
        key = hashlib.md5(key).hexdigest()
        key = int(key, 16)
        return key
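# NOTE: illustrative sketch, not part of the original module. Because __hash__
# is derived from the data set configuration via MD5 rather than from object
# identity, the cache key computed in load() stays stable across runs and
# processes; the subclass below is the hypothetical one sketched earlier:
#
#     a = MyCorpusDataset("/data/mycorpus")
#     b = MyCorpusDataset("/data/mycorpus")
#     assert hash(a) == hash(b)  # same path/version/filters -> same cache key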