Source code for setigen.split_utils

from __future__ import annotations

import os
import errno
from pathlib import Path
import numpy as np
from blimpy import Waterfall
from typing import Iterator

from ._typing import PathLike


def split_waterfall_generator(
    waterfall_fn: PathLike,
    fchans: int,
    tchans: int | None = None,
    f_shift: int | None = None,
) -> Iterator[Waterfall]:
    """Yield smaller waterfall views split from a larger filterbank file.

    Splits are taken in channel order starting from the first channel
    (``fch1``), advancing ``f_shift`` channels at a time. Only splits that
    fit entirely inside the file's ``nchans`` channels are yielded; a
    trailing partial split is dropped.

    Args:
        waterfall_fn: Input filterbank filename.
        fchans: Number of frequency samples per split. Must be positive.
        tchans: Optional number of time samples to include. Defaults to all
            time samples in the file.
        f_shift: Optional shift in frequency bins between splits. Defaults
            to ``fchans`` (adjacent, non-overlapping splits). Must be
            positive.

    Yields:
        Waterfall views covering smaller sections of the input.

    Raises:
        ValueError: If `fchans` or `f_shift` is not positive, or if
            `tchans` exceeds the available number of time samples.
    """
    # Validate before opening the file. A zero shift would yield the same
    # split forever, and a negative one would step outside the band.
    if fchans <= 0:
        raise ValueError(f"Invalid number of frequency samples: {fchans}")
    if f_shift is None:
        f_shift = fchans
    elif f_shift <= 0:
        raise ValueError(f"Invalid frequency shift: {f_shift}")

    # Opened with load_data=False; only header/selection metadata is used.
    info_wf = Waterfall(waterfall_fn, load_data=False)
    fch1 = info_wf.header['fch1']
    nchans = info_wf.header['nchans']
    df = info_wf.header['foff']
    tchans_tot = info_wf.container.selection_shape[0]

    if tchans is None:
        tchans = tchans_tot
    elif tchans > tchans_tot:
        raise ValueError(
            f"tchans ({tchans}) must not exceed the total number of time "
            f"samples in the observation ({tchans_tot})"
        )

    # Step through integer channel offsets instead of accumulating
    # ``f_shift * df`` into float frequencies, so round-off can neither
    # drop the last full split nor admit one past the band edge.
    start_chan = 0
    while start_chan + fchans <= nchans:
        # df may be negative (channels descending from fch1), so order the
        # two edges before passing them as f_start < f_stop.
        edge_a = fch1 + start_chan * df
        edge_b = fch1 + (start_chan + fchans) * df
        yield Waterfall(
            waterfall_fn,
            f_start=min(edge_a, edge_b),
            f_stop=max(edge_a, edge_b),
            t_start=0,
            t_stop=tchans,
        )
        start_chan += f_shift
def split_fil(
    waterfall_fn: PathLike,
    output_dir: PathLike,
    fchans: int,
    tchans: int | None = None,
    f_shift: int | None = None,
) -> list[Path]:
    """Split a filterbank file into smaller `.fil` files.

    Each split produced by ``split_waterfall_generator`` is written to
    ``output_dir / f"{fchans}_{i:04d}.fil"``, where ``i`` is the split's
    position in generation order, and a "Saved ..." line is printed.

    Args:
        waterfall_fn: Input filterbank filename.
        output_dir: Directory for the new filterbank files. Created
            (including parents) if it does not exist.
        fchans: Number of frequency samples per split file.
        tchans: Optional number of time samples to include.
        f_shift: Optional shift in frequency bins between splits.

    Returns:
        Paths to the new filterbank files, in the order they were written.

    Raises:
        FileExistsError: If `output_dir` exists but is not a directory.
    """
    output_dir = Path(output_dir)
    # An existing directory is fine; an existing non-directory is an error
    # raised here rather than later when the first file is written.
    output_dir.mkdir(parents=True, exist_ok=True)

    split_generator = split_waterfall_generator(waterfall_fn,
                                                fchans,
                                                tchans=tchans,
                                                f_shift=f_shift)

    split_fns = []
    # Filenames are indexed in the order the generator yields splits.
    for i, waterfall in enumerate(split_generator):
        output_fn = output_dir / f"{fchans}_{i:04d}.fil"
        waterfall.write_to_fil(output_fn)
        split_fns.append(output_fn)
        print(f"Saved {output_fn}")
    return split_fns
def _window_starts(extent: int, size: int, shift: int) -> list[int]:
    """Return window start offsets along one axis of length ``extent``.

    The first window always starts at 0, even if it oversteps the axis.
    Further windows are added while the previous one ends before the axis
    does, and only if the new start still lies inside the axis, so no empty
    windows are produced. A non-positive ``shift`` yields only the first
    window.
    """
    starts = [0]
    while shift > 0 and starts[-1] + size < extent:
        next_start = starts[-1] + shift
        if next_start >= extent:
            # A window starting past the edge would be empty.
            break
        starts.append(next_start)
    return starts


def split_array(data: np.ndarray,
                f_sample_num: int | None = None,
                t_sample_num: int | None = None,
                f_shift: int | None = None,
                t_shift: int | None = None,
                f_trim: bool = False,
                t_trim: bool = False) -> np.ndarray:
    """Split an array into smaller frequency-time windows.

    Windows are ordered row-major: every frequency offset for the first
    time offset, then every frequency offset for the next time offset, and
    so on. Windows at the array edges may be smaller than requested unless
    trimmed.

    Args:
        data: Two-dimensional time-frequency data array, indexed as
            ``data[time, frequency]``.
        f_sample_num: Number of frequency samples per split. Defaults to
            the full width.
        t_sample_num: Number of time samples per split. Defaults to the
            full height.
        f_shift: Shift in frequency bins between splits. Defaults to
            ``f_sample_num``.
        t_shift: Shift in time bins between splits. Defaults to
            ``t_sample_num``.
        f_trim: Whether to drop splits with incomplete frequency width.
        t_trim: Whether to drop splits with incomplete time height.

    Returns:
        If all remaining windows share one shape, an array of shape
        ``(n_windows, t, f)``. Otherwise (partial edge windows kept because
        trimming is off), a 1-D object array holding a copy of each window.

    Raises:
        ValueError: If the input is not a two-dimensional ndarray or an
            invalid shift is supplied.
    """
    if not isinstance(data, np.ndarray):
        raise ValueError("Input data must be a numpy array")
    if data.ndim != 2:
        raise ValueError(
            f"Input data must be two-dimensional, got shape {data.shape}"
        )
    height, width = data.shape

    if f_sample_num is None:
        f_sample_num = width
    if t_sample_num is None:
        t_sample_num = height

    if f_shift is None:
        f_shift = f_sample_num
    elif f_shift <= 0:
        raise ValueError(f"Invalid x-direction shift: {f_shift}")
    if t_shift is None:
        t_shift = t_sample_num
    elif t_shift <= 0:
        raise ValueError(f"Invalid y-direction shift: {t_shift}")

    t_starts = _window_starts(height, t_sample_num, t_shift)
    f_starts = _window_starts(width, f_sample_num, f_shift)

    split_data = [
        data[y:min(y + t_sample_num, height), x:min(x + f_sample_num, width)]
        for y in t_starts
        for x in f_starts
    ]

    # Filter out frames that aren't the specified size
    if t_trim:
        split_data = [A for A in split_data if A.shape[0] == t_sample_num]
    if f_trim:
        split_data = [A for A in split_data if A.shape[1] == f_sample_num]

    if len({frame.shape for frame in split_data}) <= 1:
        return np.array(split_data)

    # Ragged windows: NumPy >= 1.24 refuses to infer dtype=object from
    # differently shaped arrays, so build the object array explicitly
    # (element by element, to avoid broadcasting attempts).
    ragged = np.empty(len(split_data), dtype=object)
    for i, frame in enumerate(split_data):
        ragged[i] = frame.copy()
    return ragged