Coverage for src/signal_edges/signal/sample/sample.py: 30%

144 statements  

« prev     ^ index     » next       coverage.py v7.4.3, created at 2024-04-21 11:16 +0000

"""The sample module introduces the :class:`.Sample`, :class:`.Waveform` and :class:`.SampleManager` classes, a
collection of utilities specially intended for the analysis of signals in notebooks, although they are also suitable
for use outside them.

In this context, a `sample` is a collection of `waveforms`. Each `sample` is stored in a directory managed by
:class:`.SampleManager`, with the following structure:

.. code-block::

    samples
        sample_000
            waveform_000.npz
            waveform_001.npz
            ...
        sample_001
            waveform_000.npz
            waveform_001.npz
            ...
        ...

Each waveform is a ``.npz`` file generated with :func:`numpy.savez_compressed`, with additional metadata.

Each sample in the ``samples`` directory is identified by an integer number, and similarly each `waveform` in
each sample is identified by an integer number. For example notebooks that use :class:`.Sample`,
:class:`.SampleManager` and :class:`.Waveform`, refer to the ``other/notebook`` folders in the repository."""

25 

26import os 

27import re 

28import shutil 

29import sys 

30import tempfile 

31from collections.abc import Sequence 

32from typing import Any, Literal, TypeAlias, cast 

33 

34import numpy as np 

35import numpy.typing as npt 

36 

37try: 

38 import IPython 

39 import IPython.display 

40except ImportError: 

41 pass 

42 

43 

44from ... import plotter as sep 

45from ...exceptions import SignalError 

46from ..edges import Edge 

47from ..signal import Signal 

48from ..state_levels import StateLevels 

49 

50 

class Waveform:
    """Utility class to handle a waveform within :class:`.Sample`.

    A waveform is loaded from a ``.npz`` file that contains the ``hvalues`` and ``vvalues`` arrays; every other array
    stored in the file is exposed as metadata through :attr:`meta`."""

    ## Private API #####################################################################################################
    def __init__(self, sid: int, wid: int, path: str) -> None:
        """Class constructor.

        The file is read with :func:`numpy.load`; errors raised while reading it (e.g. a missing file or missing
        ``hvalues``/``vvalues`` entries) propagate unchanged.

        :param sid: The identifier of the sample that contains the waveform.
        :param wid: The waveform identifier.
        :param path: The path to the waveform file."""
        #: Sample identifier.
        self._sid = sid
        #: Waveform identifier.
        self._wid = wid

        # Load the waveform file; the context manager closes the archive once the arrays have been read.
        with np.load(path) as data:
            #: The values of the horizontal axis.
            self._hvalues: npt.NDArray[np.float64] = data["hvalues"]
            #: The values of the vertical axis.
            self._vvalues: npt.NDArray[np.float64] = data["vvalues"]
            #: The metadata values, i.e. every other array stored in the file.
            self._meta: dict[str, Any] = {k: v for k, v in data.items() if k not in ("hvalues", "vvalues")}

    ## Protected API ###################################################################################################

    ## Public API ######################################################################################################
    @property
    def sid(self) -> int:
        """The sample identifier of the sample that contains this waveform.

        :return: Sample identifier."""
        return self._sid

    @property
    def wid(self) -> int:
        """The waveform identifier within the sample.

        :return: Waveform identifier."""
        return self._wid

    @property
    def hvalues(self) -> npt.NDArray[np.float64]:
        """The values of the horizontal axis, as stored under ``hvalues`` in the waveform file.

        :return: The array with the values."""
        return self._hvalues

    @property
    def vvalues(self) -> npt.NDArray[np.float64]:
        """The values of the vertical axis, as stored under ``vvalues`` in the waveform file.

        :return: The array with the values."""
        return self._vvalues

    @property
    def meta(self) -> dict[str, Any]:
        """Dictionary with the metadata of the waveform: every entry of the file except ``hvalues`` and ``vvalues``.

        :return: Metadata of the waveform."""
        return self._meta

115 

116 

class Sample:
    """A sample, i.e. a group of :class:`.Waveform` instances indexed by their waveform identifiers."""

    ## Private API #####################################################################################################
    def __init__(self, sid: int, waveforms: Sequence[Waveform]) -> None:
        """Class constructor.

        :param sid: The sample identifier.
        :param waveforms: The waveforms of the sample; if two of them share a waveform identifier, the later one is
            kept."""
        #: Sample identifier.
        self._sid = sid

        by_wid: dict[int, Waveform] = {}
        for waveform in waveforms:
            by_wid[waveform.wid] = waveform
        #: The waveforms of the sample, keyed by waveform identifier in the order they were given.
        self._waveforms = by_wid

    ## Protected API ###################################################################################################

    ## Public API ######################################################################################################
    @property
    def sid(self) -> int:
        """Identifier of this sample.

        :return: The sample identifier."""
        return self._sid

    @property
    def waveforms(self) -> dict[int, Waveform]:
        """Waveforms of this sample, keyed by waveform identifier.

        :return: Mapping from waveform identifier to the corresponding waveform."""
        return self._waveforms

147 

148 

#: Plot item payload for a signal: ``(begin, end, munits, name, color, signal)``.
ItemSignal: TypeAlias = tuple[float, float, float, str, str, Signal]
#: Plot item payload for state levels: ``(begin, end, munits, signal, state_levels)``.
ItemStateLevels: TypeAlias = tuple[float, float, float, Signal, StateLevels]
#: Plot item payload for edges: ``(begin, end, munits, signal, edges)``.
ItemEdges: TypeAlias = tuple[float, float, float, Signal, Sequence[Edge]]
#: Plot item accepted by :meth:`.SampleManager.plot`: a tag paired with the payload that the tag selects. Each tag is
#: bound to its own payload type so that type checkers reject mismatches such as an ``"edges"`` tag with a signal.
Item: TypeAlias = (
    tuple[Literal["signal"], ItemSignal]
    | tuple[Literal["state_levels"], ItemStateLevels]
    | tuple[Literal["edges"], ItemEdges]
)

153 

154 

class SampleManager:
    """Manages the samples in the given root directory.

    Each sample lives in a ``sample_NNN`` sub-directory of the root, and each of its waveforms in a
    ``waveform_NNN.npz`` file, where ``NNN`` is the identifier zero-padded to at least three digits."""

    #: Plot label of each state level, in the order the levels are plotted.
    _LEVEL_LABELS = {
        "highest": "Highest State Level",
        "high": "High State Level",
        "high_runt": "High State Level (Runt)",
        "intermediate": "Intermediate State Level",
        "low_runt": "Low State Level (Runt)",
        "low": "Low State Level",
        "lowest": "Lowest State Level",
    }
    #: Plot label and marker of each edge point, in the order the points are plotted.
    _POINT_STYLES = {
        "begin": ("Begin Edge Point", ">"),
        "intermediate": ("Intermediate Edge Point", "8"),
        "end": ("End Edge Point", "<"),
    }
    #: Matches the names produced by :meth:`_get_spath` / :meth:`_get_wpath`: ``:03`` zero-pads only up to three
    #: digits, so an identifier is either exactly three digits or more digits without a leading zero.
    _SAMPLE_RE = re.compile(r"sample_([0-9]{3}|[1-9][0-9]{3,})")
    _WAVEFORM_RE = re.compile(r"waveform_([0-9]{3}|[1-9][0-9]{3,})\.npz")

    ## Private API #####################################################################################################
    def __init__(self, root: str) -> None:
        """Class constructor.

        :param root: Root directory where the samples are located.
        :raise SignalError: The root directory does not exist or is not a directory."""
        # ``isdir`` is also ``False`` for missing paths, so a single check covers both cases.
        if not os.path.isdir(root):
            raise SignalError(f"The path to samples '{root}' does not exist.")
        #: Path to the root directory with the samples.
        self._root = os.path.normpath(os.path.realpath(root))

    ## Protected API ###################################################################################################
    def _get_spath(self, sid: int) -> str:
        """Obtains the path to the specified sample.

        :param sid: The sample identifier.
        :return: The path to the sample folder."""
        return os.path.join(self._root, f"sample_{sid:03}")

    def _get_wpath(self, sid: int, wid: int) -> str:
        """Obtains the path to a waveform in a sample.

        :param sid: The sample identifier.
        :param wid: The waveform identifier.
        :return: The path to the waveform file."""
        return os.path.join(self._get_spath(sid), f"waveform_{wid:03}.npz")

    @staticmethod
    def _add_signal(plotter: sep.Plotter, row_i: int, item: ItemSignal) -> None:
        """Adds the subplot of a ``"signal"`` plot item to the given row of the plotter."""
        (begin, end, munits, name, color, signal) = item
        (hvalues, vvalues) = (getattr(signal, "_hv"), getattr(signal, "_vv"))
        (hunits, vunits) = (getattr(signal, "_hunits"), getattr(signal, "_vunits"))
        spl = sep.Subplot(name, hvalues, hunits, vvalues, vunits, begin, end, munits, color=color)
        plotter.add_plot(row_i, 0, spl)

    @staticmethod
    def _add_state_levels(plotter: sep.Plotter, row_i: int, item: ItemStateLevels, levels: Sequence[str]) -> None:
        """Adds one subplot per selected state level of a ``"state_levels"`` plot item to the given row."""
        (begin, end, munits, signal, state_levels) = item
        (hunits, vunits) = (getattr(signal, "_hunits"), getattr(signal, "_vunits"))
        state_levels_to_array = getattr(signal, "state_levels_to_array")

        for level, label in SampleManager._LEVEL_LABELS.items():
            # An empty selection means that every level is plotted.
            if len(levels) > 0 and level not in levels:
                continue
            (level_x, level_y) = state_levels_to_array(state_levels, level)
            subplot = sep.Subplot(
                label, level_x, hunits, level_y, vunits, begin, end, munits, "#7F7F7F", linestyle="dotted", marker="none"
            )
            plotter.add_plot(row_i, 0, subplot)

    @staticmethod
    def _add_edges(plotter: sep.Plotter, row_i: int, item: ItemEdges, points: Sequence[str]) -> None:
        """Adds one subplot per selected edge point of an ``"edges"`` plot item to the given row, if any edges."""
        (begin, end, munits, signal, edges) = item
        (hunits, vunits) = (getattr(signal, "_hunits"), getattr(signal, "_vunits"))
        edges_to_array = getattr(signal, "edges_to_array")

        if len(edges) == 0:
            return

        for point, (label, marker) in SampleManager._POINT_STYLES.items():
            # An empty selection means that every point is plotted.
            if len(points) > 0 and point not in points:
                continue
            (edges_x, edges_y) = edges_to_array(edges, point)

            # For intermediate points, the values can be repeated for edges that share the same point, fetch unique
            # values for plotting.
            if point == "intermediate":
                unique_indices = np.unique(edges_x, return_index=True)[1]
                (edges_x, edges_y) = (edges_x[unique_indices], edges_y[unique_indices])

            subplot = sep.Subplot(
                label, edges_x, hunits, edges_y, vunits, begin, end, munits, "white", linestyle="none", marker=marker
            )
            plotter.add_plot(row_i, 0, subplot)

    @staticmethod
    def _plot_to_notebook(plotter: sep.Plotter, *args, **kwargs) -> None:
        """Plots to a temporary PNG file, displays it in the notebook and deletes the file.

        :raise SignalError: IPython was not imported."""
        # Check if relevant module was imported.
        if "IPython" not in sys.modules:
            raise SignalError("Can't plot to notebook because IPython was not imported.")

        # Create the temporary file and close our handle right away, so the plotter can write it by name on every
        # platform; the ``finally`` guarantees the file is removed even if plotting or displaying fails.
        (handle, tmp_path) = tempfile.mkstemp(suffix=".png")
        os.close(handle)
        tmp_path = os.path.normpath(tmp_path)
        try:
            plotter.plot(tmp_path, *args, **kwargs)
            IPython.display.display(IPython.display.Image(tmp_path))  # type: ignore
        finally:
            os.unlink(tmp_path)

    ## Public API ######################################################################################################
    @staticmethod
    def plot(
        row_0: Sequence[Item],
        *args,
        path: str | None = None,
        mode: sep.Mode = sep.Mode.LINEAR,
        points: Sequence[Literal["begin", "intermediate", "end"]] = (),
        levels: Sequence[Literal["highest", "high", "high_runt", "intermediate", "low_runt", "low", "lowest"]] = (),
        row_1: Sequence[Item] = (),
        row_2: Sequence[Item] = (),
        row_3: Sequence[Item] = (),
        cursors: Sequence[sep.Cursor] = (),
        **kwargs,
    ) -> None:
        """Shortcut for complex plots related to samples, refer to implementation and code snippets for more details.

        The first row always exists; each of ``row_1`` to ``row_3`` adds a row only if it has items, and the non-empty
        rows are placed one after another in that order.

        :param row_0: Plot items for the first row.
        :param args: The arguments to pass to the plotter, see :meth:`.Plotter.plot`.
        :param path: Path where to save the resulting plot, see :meth:`.Plotter.plot`, or ``None`` for notebooks.
        :param mode: The mode of operation of the plotter, see :meth:`.Plotter.plot`.
        :param points: When plotting edges, the point of the edges to plot, defaults to all.
        :param levels: When plotting state levels, the levels to plot, defaults to all.
        :param row_1: Plot items for the second row, if any.
        :param row_2: Plot items for the third row, if any.
        :param row_3: Plot items for the fourth row, if any.
        :param cursors: Cursors for the plot.
        :param kwargs: The keyword arguments to pass to the plotter, see :meth:`.Plotter.plot`.
        :raise SignalError: An item identifier is invalid or not recognized.
        :raise SignalError: ``path`` is ``None`` and IPython was not imported."""
        # Pack the non-empty optional rows after the first one, so that the row indices used below always stay within
        # the number of rows the plotter is created with (e.g. ``row_2`` without ``row_1`` goes to the second row).
        rows = [row_0, *(extra for extra in (row_1, row_2, row_3) if len(extra) > 0)]

        # Create plotter.
        plotter = sep.Plotter(mode=mode, rows=len(rows), columns=1)

        # Handle signals, state levels and edges.
        for row_i, row in enumerate(rows):
            for item_id, item in row:
                if item_id == "signal":
                    SampleManager._add_signal(plotter, row_i, cast(ItemSignal, item))
                elif item_id == "state_levels":
                    SampleManager._add_state_levels(plotter, row_i, cast(ItemStateLevels, item), levels)
                elif item_id == "edges":
                    SampleManager._add_edges(plotter, row_i, cast(ItemEdges, item), points)
                else:
                    raise SignalError(f"Invalid item identifier '{item_id}' for item in plot.")

        # Handle cursors.
        for cursor in cursors:
            plotter.add_cursor(cursor)

        # Run plotter if plotting to file, otherwise plot and display.
        if path is None:
            SampleManager._plot_to_notebook(plotter, *args, **kwargs)
        else:
            plotter.plot(path, *args, **kwargs)

    def get_sids(self) -> tuple[int, ...]:
        """Obtains the existing sample identifiers in the root folder.

        :return: The sample identifiers, in ascending order."""
        sids = []
        for name in os.listdir(self._root):
            match = self._SAMPLE_RE.fullmatch(name)
            if match is not None and os.path.isdir(os.path.join(self._root, name)):
                sids.append(int(match.group(1)))
        return tuple(sorted(sids))

    def get_wids(self, sid: int) -> tuple[int, ...]:
        """Obtains the waveform identifiers of a sample in the root folder.

        :param sid: The sample identifier.
        :return: The waveform identifiers for the sample, in ascending order."""
        spath = self._get_spath(sid)
        wids = []
        for name in os.listdir(spath):
            match = self._WAVEFORM_RE.fullmatch(name)
            if match is not None and os.path.isfile(os.path.join(spath, name)):
                wids.append(int(match.group(1)))
        return tuple(sorted(wids))

    def new(self, sid: int, waveforms: Sequence[dict[str, Any]], overwrite: bool = False) -> Sample:
        """Creates a new sample with the data for the waveforms provided in the format below:

        .. code-block:: json

            {
                "wid": "An integer with the waveform identifier",
                "vvalues": "Numpy array with values for the vertical axis of a signal",
                "hvalues": "Numpy array with values for the horizontal axis of a signal"
            }

        Any other value in the waveform dictionary is stored in the `.npz` file as metadata. All waveforms are
        validated before the file system is modified, so an existing sample is never deleted for a failing request.

        :param sid: The sample identifier of the new sample.
        :param waveforms: The waveforms for the sample following the format described above for each.
        :param overwrite: Overwrite the sample if it exists, if ``False`` then raise exception instead.
        :raise SignalError: The specified sample already exists and ``overwrite`` was set to ``False``.
        :raise SignalError: At least one of the waveforms is not in the correct format.
        :return: The sample created."""
        # The waveforms are iterated twice (validation, then storage), so materialize them once.
        waveforms = tuple(waveforms)

        # Get path for sample, and check if it exists, handling overwrite.
        spath = self._get_spath(sid)
        exists = os.path.exists(spath)
        if exists and not overwrite:
            raise SignalError(f"Sample at path '{spath}' already exists.")

        # Validate every waveform before deleting or creating anything.
        for wfm in waveforms:
            # Ensure mandatory keys exist.
            if not all(key in wfm for key in ("wid", "vvalues", "hvalues")):
                raise SignalError("At least one of the mandatory keys for the waveform dictionary is missing.")
            # The file name zero-pads the identifier with ``:03``; anything but a non-negative integer would produce
            # a name that ``get_wids`` cannot list back, silently losing the waveform on reload.
            wid = wfm["wid"]
            if not isinstance(wid, (int, np.integer)) or wid < 0:
                raise SignalError(f"The waveform identifier '{wid}' is not a non-negative integer.")

        # Create directory structure, replacing the previous sample if requested.
        if exists:
            shutil.rmtree(spath)
        os.makedirs(spath)

        # Store each waveform; every key except ``wid`` becomes an array in the ``.npz`` file.
        for wfm in waveforms:
            np.savez_compressed(self._get_wpath(sid, wfm["wid"]), **{k: v for k, v in wfm.items() if k != "wid"})

        # Return loaded sample.
        return self.load(sid)

    def load(self, sid: int) -> Sample:
        """Loads the specified sample from the root folder.

        :param sid: The sample identifier.
        :raise SignalError: The sample specified does not exist.
        :return: The sample with its waveforms."""
        # Get path to sample.
        if not os.path.exists(spath := self._get_spath(sid)):
            raise SignalError(f"The sample at '{spath}' does not exist.")
        # Return loaded sample.
        return Sample(sid, [Waveform(sid, wid, self._get_wpath(sid, wid)) for wid in self.get_wids(sid)])

    def save(self, sid: int, sample: Sample, overwrite: bool = False) -> Sample:
        """Saves a sample to the root folder under the given sample identifier.

        The identifier stored in ``sample`` itself is not used; the waveforms are written with the identifiers that
        key :attr:`.Sample.waveforms`, together with their values and metadata.

        :param sid: The sample identifier to save the sample under.
        :param sample: The sample to save.
        :param overwrite: Overwrite the sample if it exists, if ``False`` then raise exception instead.
        :raise SignalError: The specified sample already exists and ``overwrite`` was set to ``False``.
        :return: The sample that was just saved, as loaded back from the root folder."""
        # Convert sample to waveform dictionaries and create anew.
        return self.new(
            sid,
            [{"wid": k, "hvalues": v.hvalues, "vvalues": v.vvalues, **v.meta} for k, v in sample.waveforms.items()],
            overwrite,
        )