Coverage for src/signal_edges/signal/sample/sample.py: 30%
144 statements
« prev ^ index » next coverage.py v7.4.3, created at 2024-04-21 11:16 +0000
« prev ^ index » next coverage.py v7.4.3, created at 2024-04-21 11:16 +0000
"""The sample module introduces the :class:`.Sample` and :class:`.Waveform` classes, a collection of utilities
specially intended for analysis of signals in notebooks, although they are also suitable for use outside these.

In this context, a `sample` is a collection of `waveforms`; each `sample` is stored in a directory,
managed by :class:`.SampleManager`, with the following structure:

.. code-block::

    samples
        sample_000
            waveform_000.npz
            waveform_001.npz
            ...
        sample_001
            waveform_000.npz
            waveform_001.npz
            ...
        ...

Each waveform is a ``.npz`` file generated with :meth:`numpy.savez_compressed` with additional metadata.

Each sample in the ``samples`` directory is identified by an integer number, and similarly each `waveform` in
each sample is identified by an integer number. For example notebooks where :class:`.Sample`, :class:`.SampleManager`
and :class:`.Waveform` are used, refer to the `other/notebook` folders in the repository."""
26import os
27import re
28import shutil
29import sys
30import tempfile
31from collections.abc import Sequence
32from typing import Any, Literal, TypeAlias, cast
34import numpy as np
35import numpy.typing as npt
37try:
38 import IPython
39 import IPython.display
40except ImportError:
41 pass
44from ... import plotter as sep
45from ...exceptions import SignalError
46from ..edges import Edge
47from ..signal import Signal
48from ..state_levels import StateLevels
class Waveform:
    """Utility class to handle a waveform within :class:`.Sample`."""

    ## Private API #####################################################################################################
    def __init__(self, sid: int, wid: int, path: str) -> None:
        """Class constructor, loads the waveform stored in the ``.npz`` file at the given path.

        :param sid: The identifier of the sample that contains the waveform.
        :param wid: The waveform identifier.
        :param path: The path to the waveform file."""
        #: Sample identifier.
        self._sid = sid
        #: Waveform identifier.
        self._wid = wid

        # Load waveform file and extract relevant data, everything is read while the file is still open.
        with np.load(path) as data:
            #: The values of the horizontal axis.
            self._hvalues: npt.NDArray[np.float64] = data["hvalues"]
            #: The values of the vertical axis.
            self._vvalues: npt.NDArray[np.float64] = data["vvalues"]
            #: The metadata values, that is, every other entry of the file. Only the names are iterated so that
            #: the arrays for the axes are not loaded a second time just to be discarded.
            self._meta: dict[str, Any] = {k: data[k] for k in data.files if k not in ("hvalues", "vvalues")}

    ## Protected API ###################################################################################################

    ## Public API ######################################################################################################
    @property
    def sid(self) -> int:
        """The sample identifier of the sample that contains this waveform.

        :return: Sample identifier."""
        return self._sid

    @property
    def wid(self) -> int:
        """The waveform identifier within the sample.

        :return: Waveform identifier."""
        return self._wid

    @property
    def hvalues(self) -> npt.NDArray[np.float64]:
        """The values of the horizontal axis.

        :return: The array with the values."""
        return self._hvalues

    @property
    def vvalues(self) -> npt.NDArray[np.float64]:
        """The values of the vertical axis.

        :return: The array with the values."""
        return self._vvalues

    @property
    def meta(self) -> dict[str, Any]:
        """Dictionary with the metadata of the waveform, with the values as loaded from the waveform file.

        :return: Metadata of the waveform."""
        return self._meta
class Sample:
    """Groups the :class:`.Waveform` instances of a sample under the identifier of the sample."""

    ## Private API #####################################################################################################
    def __init__(self, sid: int, waveforms: Sequence[Waveform]) -> None:
        """Class constructor.

        :param sid: The sample identifier.
        :param waveforms: The waveforms that belong to the sample."""
        #: Sample identifier.
        self._sid = sid
        #: The waveforms of the sample indexed by waveform identifier, for waveforms that share the
        #: same identifier the last one in the sequence is the one that is kept.
        self._waveforms: dict[int, Waveform] = {}
        for waveform in waveforms:
            self._waveforms[waveform.wid] = waveform

    ## Protected API ###################################################################################################

    ## Public API ######################################################################################################
    @property
    def sid(self) -> int:
        """The identifier of this sample.

        :return: The sample identifier."""
        return self._sid

    @property
    def waveforms(self) -> dict[int, Waveform]:
        """The waveforms of the sample as a dictionary that maps each waveform identifier to its waveform.

        :return: Dictionary with the waveforms for the sample."""
        return self._waveforms
#: Payload of a ``"signal"`` plot item, unpacked as ``(begin, end, munits, name, color, signal)``.
ItemSignal: TypeAlias = tuple[
    float,
    float,
    float,
    str,
    str,
    Signal,
]
#: Payload of a ``"state_levels"`` plot item, unpacked as ``(begin, end, munits, signal, state_levels)``.
ItemStateLevels: TypeAlias = tuple[
    float,
    float,
    float,
    Signal,
    StateLevels,
]
#: Payload of an ``"edges"`` plot item, unpacked as ``(begin, end, munits, signal, edges)``.
ItemEdges: TypeAlias = tuple[
    float,
    float,
    float,
    Signal,
    Sequence[Edge],
]
#: A plot item, a pair with the identifier of the kind of item followed by the payload for that kind.
Item: TypeAlias = tuple[
    Literal["signal", "state_levels", "edges"],
    ItemSignal | ItemStateLevels | ItemEdges,
]
class SampleManager:
    """Manages the samples in the given root directory."""

    ## Private API #####################################################################################################
    def __init__(self, root: str) -> None:
        """Class constructor.

        :param root: Root directory where the samples are located.
        :raise SignalError: The root directory does not exist, or it is not a directory."""
        # Checking for a directory is enough, the check is also ``False`` for paths that do not exist.
        if not os.path.isdir(root):
            raise SignalError(f"The path to samples '{root}' does not exist.")
        #: Path to the root directory with the samples.
        self._root = os.path.normpath(os.path.realpath(root))

    ## Protected API ###################################################################################################
    def _get_spath(self, sid: int) -> str:
        """Obtains the path to the specified sample.

        :param sid: The sample identifier.
        :return: The path to the sample folder."""
        return os.path.join(self._root, f"sample_{sid:03}")

    def _get_wpath(self, sid: int, wid: int) -> str:
        """Obtains the path to a waveform in a sample.

        :param sid: The sample identifier.
        :param wid: The waveform identifier.
        :return: The path to the waveform file."""
        return os.path.join(self._get_spath(sid), f"waveform_{wid:03}.npz")

    ## Public API ######################################################################################################
    @staticmethod
    def plot(
        row_0: Sequence[Item],
        *args,
        path: str | None = None,
        mode: sep.Mode = sep.Mode.LINEAR,
        points: Sequence[Literal["begin", "intermediate", "end"]] = (),
        levels: Sequence[Literal["highest", "high", "high_runt", "intermediate", "low_runt", "low", "lowest"]] = (),
        row_1: Sequence[Item] = (),
        row_2: Sequence[Item] = (),
        row_3: Sequence[Item] = (),
        cursors: Sequence[sep.Cursor] = (),
        **kwargs,
    ) -> None:
        """Shortcut for complex plots related to samples, refer to implementation and code snippets for more details.

        Each item in a row is a pair with the kind of item and its payload:

        - ``"signal"``: ``(begin, end, munits, name, color, signal)``, adds a plot with the values of the signal.
        - ``"state_levels"``: ``(begin, end, munits, signal, state_levels)``, adds a plot for each selected level.
        - ``"edges"``: ``(begin, end, munits, signal, edges)``, adds a plot for each selected point of the edges.

        The first row always takes the first row of the plot, the remaining rows only take a row of the plot if they
        have items, in which case they are placed one after the other without leaving empty rows in between.

        :param row_0: Plot items for the first row.
        :param args: The arguments to pass to the plotter, see :meth:`.Plotter.plot`.
        :param path: Path where to save the resulting plot, see :meth:`.Plotter.plot`, or ``None`` for notebooks.
        :param mode: The mode of operation of the plotter, see :meth:`.Plotter.plot`.
        :param points: When plotting edges, the point of the edges to plot, defaults to all.
        :param levels: When plotting state levels, the levels to plot, defaults to all.
        :param row_1: Plot items for the second row, if any.
        :param row_2: Plot items for the third row, if any.
        :param row_3: Plot items for the fourth row, if any.
        :param cursors: Cursors for the plot.
        :param kwargs: The keyword arguments to pass to the plotter, see :meth:`.Plotter.plot`.
        :raise SignalError: An item identifier is invalid or not recognized.
        :raise SignalError: Plotting to a notebook was requested, but ``IPython`` was not imported."""
        # pylint: disable=too-complex,too-many-locals,too-many-nested-blocks,too-many-branches
        # pylint: disable=too-many-statements

        # Name of the plot for each state level, in the order in which the levels are plotted.
        level_names = {
            "highest": "Highest State Level",
            "high": "High State Level",
            "high_runt": "High State Level (Runt)",
            "intermediate": "Intermediate State Level",
            "low_runt": "Low State Level (Runt)",
            "low": "Low State Level",
            "lowest": "Lowest State Level",
        }
        # Name of the plot and marker for each point of an edge, in the order in which the points are plotted.
        point_styles = {
            "begin": ("Begin Edge Point", ">"),
            "intermediate": ("Intermediate Edge Point", "8"),
            "end": ("End Edge Point", "<"),
        }

        # Assign a row index to each row with items. The first row is always present, and the optional rows are
        # numbered consecutively so that an empty row in between does not produce an index equal or greater than
        # the number of rows of the plotter (presumably rejected by the plotter, to be confirmed against it).
        layout: list[tuple[int, Sequence[Item]]] = [(0, row_0)]
        for optional_row in (row_1, row_2, row_3):
            if len(optional_row) > 0:
                layout.append((len(layout), optional_row))

        # Create plotter.
        plotter = sep.Plotter(mode=mode, rows=len(layout), columns=1)

        # Handle signals, state levels and edges.
        for row_i, row in layout:
            for item_id, item in row:
                ########################################################################################################
                if item_id == "signal":
                    # Get relevant data from item.
                    (begin, end, munits, name, color, signal) = cast(ItemSignal, item)
                    (hvalues, vvalues) = (getattr(signal, "_hv"), getattr(signal, "_vv"))
                    (hunits, vunits) = (getattr(signal, "_hunits"), getattr(signal, "_vunits"))

                    # Add plot for the signal in relevant row and column.
                    subplot = sep.Subplot(name, hvalues, hunits, vvalues, vunits, begin, end, munits, color=color)
                    plotter.add_plot(row_i, 0, subplot)
                ########################################################################################################
                elif item_id == "state_levels":
                    # Get relevant data from item.
                    (begin, end, munits, signal, state_levels) = cast(ItemStateLevels, item)
                    (hunits, vunits) = (getattr(signal, "_hunits"), getattr(signal, "_vunits"))
                    state_levels_to_array = getattr(signal, "state_levels_to_array")

                    # Add a plot for each of the levels selected, or for all of them if none was selected.
                    for level, level_name in level_names.items():
                        if len(levels) > 0 and level not in levels:
                            continue

                        (level_x, level_y) = state_levels_to_array(state_levels, level)
                        subplot = sep.Subplot(
                            level_name,
                            level_x,
                            hunits,
                            level_y,
                            vunits,
                            begin,
                            end,
                            munits,
                            "#7F7F7F",
                            linestyle="dotted",
                            marker="none",
                        )
                        plotter.add_plot(row_i, 0, subplot)
                ########################################################################################################
                elif item_id == "edges":
                    # Get relevant data from item.
                    (begin, end, munits, signal, edges) = cast(ItemEdges, item)
                    (hunits, vunits) = (getattr(signal, "_hunits"), getattr(signal, "_vunits"))
                    edges_to_array = getattr(signal, "edges_to_array")

                    # If there are any edges, add a plot for each of the points selected, or for all of them if
                    # none was selected.
                    if len(edges) > 0:
                        for point, (point_name, point_marker) in point_styles.items():
                            if len(points) > 0 and point not in points:
                                continue

                            (edges_x, edges_y) = edges_to_array(edges, point)

                            # For intermediate points, the values can be repeated for edges that share the
                            # same point, fetch unique values for plotting.
                            if point == "intermediate":
                                unique_indices = np.unique(edges_x, return_index=True)[1]
                                (edges_x, edges_y) = (edges_x[unique_indices], edges_y[unique_indices])

                            subplot = sep.Subplot(
                                point_name,
                                edges_x,
                                hunits,
                                edges_y,
                                vunits,
                                begin,
                                end,
                                munits,
                                "white",
                                linestyle="none",
                                marker=point_marker,
                            )
                            plotter.add_plot(row_i, 0, subplot)
                ########################################################################################################
                else:
                    raise SignalError(f"Invalid item identifier '{item_id}' for item in plot.")

        # Handle cursors.
        for cursor in cursors:
            plotter.add_cursor(cursor)

        # Run plotter directly if plotting to file.
        if path is not None:
            plotter.plot(path, *args, **kwargs)
            return

        # Otherwise plot and display in notebook, check first if relevant module was imported.
        if "IPython" not in sys.modules:
            raise SignalError("Can't plot to notebook because IPython was not imported.")

        # Only the name of the temporary file is of interest, close the file right away so that the plotter can
        # write to that path while no other handle to the file is open.
        with tempfile.NamedTemporaryFile("w+", suffix=".png", delete=False) as file:
            image_path = os.path.normpath(file.name)

        # Plot to temporary file and display it in notebook, the file is always deleted afterwards, also when
        # plotting or displaying fails, so that no temporary files are left behind.
        try:
            plotter.plot(image_path, *args, **kwargs)
            IPython.display.display(IPython.display.Image(image_path))  # type: ignore
        finally:
            os.unlink(image_path)

    def get_sids(self) -> tuple[int, ...]:
        """Obtains the existing sample identifiers in the root folder.

        Only directories named as the manager names them for their identifier are considered, this is ``sample_``
        followed by the identifier padded with zeros to a minimum of three digits.

        :return: The sample identifiers, sorted in ascending order."""
        sids = []
        for name in os.listdir(self._root):
            if (match := re.fullmatch(r"sample_([0-9]{3,})", name)) is None:
                continue
            # Discard names that do not map back to themselves, such as ``sample_0001``, because the sample
            # could not be located afterwards from its identifier.
            spath = self._get_spath(sid := int(match.group(1)))
            if os.path.basename(spath) == name and os.path.isdir(spath):
                sids.append(sid)
        return tuple(sorted(sids))

    def get_wids(self, sid: int) -> tuple[int, ...]:
        """Obtains the waveform identifiers of a sample in the root folder.

        Only files named as the manager names them for their identifier are considered, this is ``waveform_``
        followed by the identifier padded with zeros to a minimum of three digits and the ``.npz`` extension.

        :param sid: The sample identifier.
        :return: The waveform identifiers for the sample, sorted in ascending order."""
        wids = []
        for name in os.listdir(self._get_spath(sid)):
            if (match := re.fullmatch(r"waveform_([0-9]{3,})\.npz", name)) is None:
                continue
            # Discard names that do not map back to themselves, such as ``waveform_0001.npz``, because the
            # waveform could not be located afterwards from its identifier.
            wpath = self._get_wpath(sid, wid := int(match.group(1)))
            if os.path.basename(wpath) == name and os.path.isfile(wpath):
                wids.append(wid)
        return tuple(sorted(wids))

    def new(self, sid: int, waveforms: Sequence[dict[str, Any]], overwrite: bool = False) -> Sample:
        """Creates a new sample with the data for the waveforms provided in the format below:

        .. code-block:: json

            {
                "wid": "An integer with the waveform identifier",
                "vvalues": "Numpy array with values for the vertical axis of a signal",
                "hvalues": "Numpy array with values for the horizontal axis of a signal"
            }

        Any other value in the waveform dictionary is stored in the `.npz` file as metadata.

        The waveforms are validated before any change is made in the root folder, therefore an existing sample is
        left untouched when any of the waveforms is not in the correct format.

        :param sid: The sample identifier of the new sample.
        :param waveforms: The waveforms for the sample following the format described above for each.
        :param overwrite: Overwrite the sample if it exists, if ``False`` then raise exception instead.
        :raise SignalError: The specified sample already exists and ``overwrite`` was set to ``False``.
        :raise SignalError: At least one of the waveforms is not in the correct format.
        :return: The sample created."""
        # Get path for sample, and check if it exists when it is not allowed to overwrite it.
        exists = os.path.exists(spath := self._get_spath(sid))
        if exists and not overwrite:
            raise SignalError(f"Sample at path '{spath}' already exists.")

        # Ensure mandatory keys exist in all the waveforms before the existing sample, if any, is removed.
        for wfm in waveforms:
            if not all(key in wfm for key in ("wid", "vvalues", "hvalues")):
                raise SignalError("At least one of the mandatory keys for the waveform dictionary is missing.")

        # Create directory structure anew.
        if exists:
            shutil.rmtree(spath)
        os.makedirs(spath)

        # Store each waveform, the identifier goes in the name of the file and everything else inside the file.
        for wfm in waveforms:
            np.savez_compressed(self._get_wpath(sid, wfm["wid"]), **{k: v for k, v in wfm.items() if k != "wid"})

        # Return loaded sample.
        return self.load(sid)

    def load(self, sid: int) -> Sample:
        """Loads the specified sample from the root folder.

        :param sid: The sample identifier.
        :raise SignalError: The sample specified does not exist.
        :return: The sample with its waveforms."""
        # Get path to sample.
        if not os.path.exists(spath := self._get_spath(sid)):
            raise SignalError(f"The sample at '{spath}' does not exist.")
        # Return loaded sample.
        return Sample(sid, [Waveform(sid, wid, self._get_wpath(sid, wid)) for wid in self.get_wids(sid)])

    def save(self, sid: int, sample: Sample, overwrite: bool = False) -> Sample:
        """Saves the sample to the root folder under the specified identifier.

        :param sid: The sample identifier.
        :param sample: The sample to save.
        :param overwrite: Overwrite the sample if it exists, if ``False`` then raise exception instead.
        :raise SignalError: The specified sample already exists and ``overwrite`` was set to ``False``.
        :return: The sample that was just saved, as loaded back from the root folder."""
        # Convert sample to waveform dictionaries and create anew.
        return self.new(
            sid,
            [{"wid": k, "hvalues": v.hvalues, "vvalues": v.vvalues, **v.meta} for k, v in sample.waveforms.items()],
            overwrite,
        )