diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
index 167b3ee..446e632 100644
--- a/.github/workflows/ci.yml
+++ b/.github/workflows/ci.yml
@@ -14,34 +14,34 @@ jobs:
       fail-fast: false
       matrix:
         os: [ubuntu-latest, macos-latest]
-        python-version: ['3.8', '3.9']
+        python-version: ['3.8', '3.9', '3.10']
     steps:
-    - uses: actions/checkout@v2
+    - uses: actions/checkout@v4
     - name: Set up Python ${{ matrix.python-version }}
-      uses: actions/setup-python@v2
+      uses: actions/setup-python@v5
       with:
        python-version: ${{ matrix.python-version }}
    - name: Install it
      run: |
        python -m pip install -U pip setuptools wheel
-        python -m pip install -e .[tests]
-    - name: Lint with flake8
+        python -m pip install .[tests] -v
+    - name: Lint with Ruff
+      uses: chartboost/ruff-action@v1
+      with:
+        args: "check"
+    - name: Check types with mypy
      run: |
-        pip install flake8 wemake-python-styleguide
-        # stop the build if there are Python syntax errors or undefined names
-        flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics
-        # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide
-        flake8 . --count --exit-zero --statistics
-    - name: Run pytest and Generate coverage report
+        mypy --strict ./sigpyproc/
+    - name: Test with pytest and Generate coverage report
      run: |
-        pip install pytest pytest-cov
        pytest --cov=./ --cov-report=xml
      continue-on-error: false
    - name: Upload coverage to Codecov
-      uses: codecov/codecov-action@v2
+      uses: codecov/codecov-action@v4
      with:
-        token: ${{ secrets.CODECOV_TOKEN }}
        file: ./coverage.xml
        name: codecov-umbrella
        fail_ci_if_error: true
+      env:
+        CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }}
diff --git a/.gitignore b/.gitignore
index f8568b2..3af2fd9 100644
--- a/.gitignore
+++ b/.gitignore
@@ -69,6 +69,9 @@ docs/jupyter_execute/
 .dmypy.json
 dmypy.json
 
+# ruff
+.ruff_cache/
+
 # Mr Developer
 .mr.developer.cfg
 .project
diff --git a/pyproject.toml b/pyproject.toml
index c7f4121..3e8b107 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,35 +1,99 @@
 [build-system]
-requires = [
-    "setuptools>=47",
-    "wheel",
+requires = ["setuptools", "wheel"]
+build-backend = "setuptools.build_meta"
+
+[project]
+name = "sigpyproc"
+version = "1.1.0"
+description = "Python FRB/pulsar data toolbox"
+readme = "README.md"
+authors = [{ name = "Ewan Barr", email = "ewan.d.barr@gmail.com" }]
+maintainers = [{ name = "Pravir Kumar", email = "pravirka@gmail.com" }]
+requires-python = ">=3.8"
+classifiers = [
+    "Programming Language :: Python :: 3 :: Only",
+    "Programming Language :: Python :: 3.8",
+    "Programming Language :: Python :: 3.9",
+    "Programming Language :: Python :: 3.10",
+    "Programming Language :: Python :: 3.11",
+    "Operating System :: MacOS :: MacOS X",
+    "Operating System :: Microsoft :: Windows",
+    "Operating System :: POSIX :: Linux",
+    "License :: OSI Approved :: MIT License",
+    "Development Status :: 4 - Beta",
+    "Intended Audience :: Science/Research",
+    "Topic :: Scientific/Engineering :: Astronomy",
+]
+dependencies = [
+    "numpy",
+    "numba",
+    "astropy",
+    "h5py",
+    "bottleneck",
+    "attrs",
+    "click",
+    "rich",
+    "bidict",
+    "typing_extensions",
 ]
-build-backend = "setuptools.build_meta"
 
+[project.urls]
+Repository = "https://github.com/FRBs/sigpyproc3"
 
+[project.optional-dependencies]
+tests = [
+    "pytest",
+    "pytest-cov",
+    "pytest-randomly",
+    "pytest-sugar",
+    "pytest-benchmark",
+    "mypy",
+]
+docs = ["sphinx", "sphinx-book-theme", "sphinx-click", "myst-nb"]
+develop = ["ruff"]
 
-[tool.black]
-line-length = 90
-target_version = ['py38', 'py39', 'py310']
+[project.scripts]
+spp_header = "sigpyproc.apps.spp_header:main"
+spp_decimate = "sigpyproc.apps.spp_decimate:main"
+spp_extract = "sigpyproc.apps.spp_extract:main"
+spp_clean = "sigpyproc.apps.spp_clean:main"
 
-[tool.pytest.ini_options]
-minversion = "6.0"
-testpaths = [
-    "tests",
-]
+[tool.ruff]
+include = ["pyproject.toml", "sigpyproc/**/*.py"]
+line-length = 88
+indent-width = 4
+target-version = "py39"
 
+[tool.ruff.format]
+quote-style = "double"
+indent-style = "space"
+
+[tool.ruff.lint]
+select = ["ALL"]
+ignore = ["D1", "ANN1", "PLR2004"]
+
+[tool.ruff.lint.pylint]
+max-args = 10
+
+[tool.ruff.lint.pydocstyle]
+convention = "numpy"
+
+
+[tool.pytest.ini_options]
+minversion = "8.0"
+testpaths = "tests"
 
 [tool.coverage.paths]
 source = ["./sigpyproc/"]
 
 [tool.coverage.run]
 omit = [
-    '*tests*',
-    '*docs*',
-    '*apps*',
-    '*setup.py',
-    '*__init__.py',
-    '*sigpyproc/core/kernels.py',
+    "tests/*",
+    "docs/*",
+    "apps/*",
+    "*__init__.py",
+    "sigpyproc/core/kernels.py",
 ]
 
 [tool.coverage.report]
@@ -38,7 +102,6 @@ show_missing = true
 ignore_errors = true
 exclude_lines = ['raise AssertionError', 'raise NotImplementedError']
 
-
 [tool.mypy]
 ignore_missing_imports = true
 plugins = ["numpy.typing.mypy_plugin"]
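Review note: the new `[project.scripts]` table carries over the four console scripts from the `console_scripts` section of the setup.cfg deleted below. A minimal sketch (not part of the diff) for confirming they are registered after `pip install .`, using only the standard library; the fallback branch covers Python 3.8/3.9, where `entry_points()` takes no `group` keyword:

```python
# Check that the spp_* console scripts from [project.scripts] are installed.
from importlib.metadata import entry_points

try:
    eps = entry_points(group="console_scripts")  # Python >= 3.10
except TypeError:
    eps = entry_points()["console_scripts"]  # Python 3.8/3.9 dict interface

for ep in eps:
    if ep.name.startswith("spp_"):
        print(ep.name, "->", ep.value)  # e.g. spp_header -> sigpyproc.apps.spp_header:main
```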
diff --git a/setup.cfg b/setup.cfg
deleted file mode 100644
index 7673253..0000000
--- a/setup.cfg
+++ /dev/null
@@ -1,123 +0,0 @@
-[metadata]
-name = sigpyproc
-version = 1.1.0
-author = Ewan Barr
-author_email = ewan.d.barr@gmail.com
-maintainer = Pravir Kumar
-maintainer_email = pravirka@gmail.com
-url = https://github.com/FRBs/sigpyproc3
-description = Python FRB/pulsar data toolbox
-long_description = file: README.md
-long_description_content_type = text/markdown
-classifiers =
-    Operating System :: OS Independent
-    Programming Language :: Python :: 3 :: Only
-    Programming Language :: Python :: 3.8
-    Programming Language :: Python :: 3.9
-    Programming Language :: Python :: 3.10
-    Intended Audience :: Science/Research
-    Topic :: Scientific/Engineering :: Astronomy
-
-[options]
-zip_safe = false
-include_package_data = false
-packages = find:
-python_requires = >=3.8
-install_requires =
-    numpy>=1.20
-    numba
-    astropy>=4.0
-    h5py
-    scipy
-    bottleneck
-    attrs
-    click
-    rich
-    bidict
-    iqrm
-    typing_extensions
-
-[options.entry_points]
-console_scripts =
-    spp_header = sigpyproc.apps.spp_header:main
-    spp_decimate = sigpyproc.apps.spp_decimate:main
-    spp_extract = sigpyproc.apps.spp_extract:main
-    spp_clean = sigpyproc.apps.spp_clean:main
-
-[options.extras_require]
-tests =
-    pytest
-    pytest-cov
-    pytest-randomly
-    pytest-sugar
-    pytest-benchmark
-    pytest-mock
-docs =
-    sphinx>=4.4.0
-    sphinx-book-theme>=0.2.0
-    myst-nb
-    sphinx-click
-
-develop =
-    wemake-python-styleguide
-    black
-    mypy
-
-[flake8]
-ignore =
-    # Default ignore
-    BLK100,
-    # Line break
-    W503,
-    # Isort
-    I,
-    # Docs
-    D10, D401
-    # Trailing commas
-    C81,
-    # Quotes
-    Q0,
-    # WPS
-    WPS100, WPS110, WPS114
-    WPS210, WPS216, WPS220, WPS237,
-    WPS305, WPS323, WPS338, WPS339, WPS345, WPS352, WPS362
-    WPS420, WPS432, WPS440, WPS441,
-    WPS519,
-    WPS602,
-    # bandit
-    S101, S105, S404, S602, S607,
-    # whitespace before colon
-    E203
-    # Numpy style RST
-    RST210, DAR401
-exclude = .git, .eggs, __pycache__, docs/, old/, build/, dist/
-i-control-code = False
-max-asserts = 10
-max-imports = 20
-max-methods = 35
-max-attributes = 30
-max-arguments = 20
-max-raises = 5
-max-complexity = 10
-max-expressions = 20
-max-string-usages = 50
-max-cognitive-score = 20
-max-line-complexity = 40
-max-module-members = 50
-max-module-expressions = 20
-max-function-expressions = 10
-max-local-variables = 10
-max-line-length = 127
-rst-roles =
-    attr,class,func,meth,mod,obj,ref,term,
-    # C programming language:
-    c:member,
-    # Python programming language:
-    py:func,py:mod,py:obj
-per-file-ignores =
-    tests/test_fileio.py: WPS437, E712
-
-
-[darglint]
-docstring_style=numpy
-ignore=DAR402,DAR103,DAR201,DAR101
diff --git a/setup.py b/setup.py
deleted file mode 100644
index a4f49f9..0000000
--- a/setup.py
+++ /dev/null
@@ -1,2 +0,0 @@
-import setuptools
-setuptools.setup()
diff --git a/sigpyproc/apps/spp_clean.py b/sigpyproc/apps/spp_clean.py
index b97f408..0a61cfa 100644
--- a/sigpyproc/apps/spp_clean.py
+++ b/sigpyproc/apps/spp_clean.py
@@ -4,7 +4,7 @@
 
 @click.command(
-    context_settings=dict(help_option_names=["-h", "--help"], show_default=True)
+    context_settings={"help_option_names": ["-h", "--help"], "show_default": True},
 )
 @click.argument("filfile", type=click.Path(exists=True))
 @click.option(
@@ -15,24 +15,45 @@
     help="RFI cleaning method to use.",
 )
 @click.option(
-    "-t", "--threshold", type=float, default=3.0, help="Sigma threshold for RFI cleaning."
+    "-t",
+    "--threshold",
+    type=float,
+    default=3.0,
+    help="Sigma threshold for RFI cleaning.",
 )
 @click.option(
-    "-g", "--gulp", type=int, default=16384, help="Number of samples to read at once"
+    "-g",
+    "--gulp",
+    type=int,
+    default=16384,
+    help="Number of samples to read at once",
 )
 @click.option(
-    "-o", "--outfile", type=click.Path(exists=False), help="Output masked filterbank file"
+    "-o",
+    "--outfile",
+    type=click.Path(exists=False),
+    help="Output masked filterbank file",
 )
 @click.option(
-    "--save/--no-save", default=True, help="Save the mask information to a file"
+    "--save/--no-save",
+    default=True,
+    help="Save the mask information to a file",
 )
 def main(
-    filfile: str, method: str, threshold: float, outfile: str, gulp: int, save: bool
+    filfile: str,
+    method: str,
+    threshold: float,
+    outfile: str,
+    gulp: int,
+    save: bool,  # noqa: FBT001
 ) -> None:
     """Clean RFI from filterbank data."""
     fil = FilReader(filfile)
     _out_file, rfimask = fil.clean_rfi(
-        method=method, threshold=threshold, filename=outfile, gulp=gulp
+        method=method,
+        threshold=threshold,
+        filename=outfile,
+        gulp=gulp,
     )
     if save:
         rfimask.to_file()
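Review note: `spp_clean` is a thin wrapper around `Filterbank.clean_rfi` (reworked in base.py further down). A minimal usage sketch assuming a hypothetical input file `observation.fil`; the call signature and the `RFIMask` attributes are taken from this diff:

```python
# Sketch of the clean_rfi flow that spp_clean wraps (assumes observation.fil exists).
from sigpyproc.readers import FilReader

fil = FilReader("observation.fil")  # hypothetical filterbank file
out_file, rfimask = fil.clean_rfi(method="mad", threshold=3.0, gulp=16384)
print(f"masked {rfimask.num_masked} channels ({rfimask.masked_fraction:.1f}%)")
rfimask.to_file()  # persist the mask, as spp_clean --save does
```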
diff --git a/sigpyproc/apps/spp_decimate.py b/sigpyproc/apps/spp_decimate.py
index 268f710..2d80c15 100644
--- a/sigpyproc/apps/spp_decimate.py
+++ b/sigpyproc/apps/spp_decimate.py
@@ -4,20 +4,36 @@
 
 @click.command(
-    context_settings=dict(help_option_names=["-h", "--help"], show_default=True)
+    context_settings={"help_option_names": ["-h", "--help"], "show_default": True},
 )
 @click.argument("filfile", type=click.Path(exists=True))
 @click.option(
-    "-t", "--tfactor", type=int, default=1, help="Number of time samples to add"
+    "-t",
+    "--tfactor",
+    type=int,
+    default=1,
+    help="Number of time samples to add",
 )
 @click.option(
-    "-c", "--ffactor", type=int, default=1, help="Number of frequency channels to add"
+    "-c",
+    "--ffactor",
+    type=int,
+    default=1,
+    help="Number of frequency channels to add",
 )
 @click.option(
-    "-g", "--gulp", type=int, default=16384, help="Number of samples to read at once"
+    "-g",
+    "--gulp",
+    type=int,
+    default=16384,
+    help="Number of samples to read at once",
 )
 @click.option(
-    "-o", "--outfile", type=click.Path(exists=False), default=None, help="Output filename"
+    "-o",
+    "--outfile",
+    type=click.Path(exists=False),
+    default=None,
+    help="Output filename",
 )
 def main(filfile: str, tfactor: int, ffactor: int, gulp: int, outfile: str) -> None:
     """Reduce time and/or frequency resolution of filterbank data."""
diff --git a/sigpyproc/apps/spp_extract.py b/sigpyproc/apps/spp_extract.py
index 8232802..4450b62 100644
--- a/sigpyproc/apps/spp_extract.py
+++ b/sigpyproc/apps/spp_extract.py
@@ -1,8 +1,12 @@
 import click
+import numpy as np
+
 from sigpyproc.readers import FilReader
 
-@click.group(context_settings=dict(help_option_names=["-h", "--help"], show_default=True))
+@click.group(
+    context_settings={"help_option_names": ["-h", "--help"], "show_default": True},
+)
 def main() -> None:
     pass
 
@@ -11,13 +15,25 @@ def main() -> None:
 @click.argument("filfile", type=click.Path(exists=True))
 @click.option("-s", "--start", type=int, required=True, help="Start time sample")
 @click.option(
-    "-n", "--nsamps", type=int, required=True, help="Number of time samples to extract"
+    "-n",
+    "--nsamps",
+    type=int,
+    required=True,
+    help="Number of time samples to extract",
 )
 @click.option(
-    "-g", "--gulp", type=int, default=16384, help="Number of samples to read at once"
+    "-g",
+    "--gulp",
+    type=int,
+    default=16384,
+    help="Number of samples to read at once",
 )
 @click.option(
-    "-o", "--outfile", type=click.Path(exists=False), default=None, help="Output filename"
+    "-o",
+    "--outfile",
+    type=click.Path(exists=False),
+    default=None,
+    help="Output filename",
 )
 def samples(filfile: str, start: int, nsamps: int, gulp: int, outfile: str) -> None:
     """Extract time samples from filterbank data."""
@@ -34,7 +50,7 @@ def samples(filfile: str, start: int, nsamps: int, gulp: int, outfile: str) -> N
     multiple=True,
     help="Channels to extract",
 )
-def channels(filfile: str, chans: int) -> None:
+def channels(filfile: str, chans: np.ndarray) -> None:
     """Extract frequency channels from filterbank data."""
     fil = FilReader(filfile)
     fil.extract_chans(chans=chans)
@@ -44,9 +60,18 @@ def channels(filfile: str, chans: int) -> None:
 @click.argument("filfile", type=click.Path(exists=True))
 @click.option("-s", "--chanstart", type=int, required=True, help="Start channel")
 @click.option(
-    "-n", "--nchans", type=int, required=True, help="Number of channels to extract"
+    "-n",
+    "--nchans",
+    type=int,
+    required=True,
+    help="Number of channels to extract",
+)
+@click.option(
+    "-c",
+    "--chanpersub",
+    type=int,
+    help="Number of channels in each sub-band",
 )
-@click.option("-c", "--chanpersub", type=int, help="Number of channels in each sub-band")
 def bands(filfile: str, chanstart: int, nchans: int, chanpersub: int) -> None:
     """Extract frequency bands from filterbank data."""
     fil = FilReader(filfile)
diff --git a/sigpyproc/apps/spp_header.py b/sigpyproc/apps/spp_header.py
index 15bb486..90353d8 100644
--- a/sigpyproc/apps/spp_header.py
+++ b/sigpyproc/apps/spp_header.py
@@ -1,18 +1,21 @@
 from __future__ import annotations
+
 import click
 
 from sigpyproc.header import Header
 from sigpyproc.io.sigproc import edit_header
 
-@click.group(context_settings=dict(help_option_names=["-h", "--help"], show_default=True))
+@click.group(
+    context_settings={"help_option_names": ["-h", "--help"], "show_default": True},
+)
 def main() -> None:
     pass
 
 
 @main.command()
 @click.argument("filfile", type=click.Path(exists=True))
-def print(filfile: str) -> None:
+def print(filfile: str) -> None:  # noqa: A001
     """Print the header information."""
     header = Header.from_sigproc(filfile)
     click.echo(header.to_string())
@@ -21,7 +24,10 @@ def print(filfile: str) -> None:
 @main.command()
 @click.argument("filfile", type=click.Path(exists=True))
 @click.option(
-    "-k", "--key", type=str, help="A header key to read (e.g. telescope, fch1, nsamples)"
+    "-k",
+    "--key",
+    type=str,
+    help="A header key to read (e.g. telescope, fch1, nsamples)",
 )
 def get(filfile: str, key: str) -> None:
     """Get the value of a header key."""
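Review note: all four apps now share the same Click idioms — a dict literal for `context_settings` (ruff flags `dict(...)` calls as C408) and one keyword per line in each `@click.option`. A self-contained toy command showing the pattern:

```python
# Toy command mirroring the spp_* option style introduced in this diff.
import click


@click.command(
    context_settings={"help_option_names": ["-h", "--help"], "show_default": True},
)
@click.argument("filfile", type=click.Path(exists=True))
@click.option(
    "-g",
    "--gulp",
    type=int,
    default=16384,
    help="Number of samples to read at once",
)
def demo(filfile: str, gulp: int) -> None:
    """Toy command mirroring the spp_* option style."""
    click.echo(f"would read {filfile} in chunks of {gulp}")


if __name__ == "__main__":
    demo()
```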
"--key", type=str, help="A header key to read (e.g. telescope, fch1, nsamples)" + "-k", + "--key", + type=str, + help="A header key to read (e.g. telescope, fch1, nsamples)", ) def get(filfile: str, key: str) -> None: """Get the value of a header key.""" diff --git a/sigpyproc/base.py b/sigpyproc/base.py index 2e777b2..e040bea 100644 --- a/sigpyproc/base.py +++ b/sigpyproc/base.py @@ -1,23 +1,36 @@ from __future__ import annotations -from typing_extensions import Buffer -import warnings -import numpy as np -from typing import Callable -from numpy import typing as npt from abc import ABC, abstractmethod -from collections.abc import Iterator +from typing import TYPE_CHECKING + +import numpy as np +from sigpyproc.core import kernels +from sigpyproc.core.rfi import RFIMask +from sigpyproc.core.stats import ChannelStats from sigpyproc.foldedcube import FoldedData from sigpyproc.timeseries import TimeSeries -from sigpyproc.header import Header -from sigpyproc.block import FilterbankBlock -from sigpyproc.core import stats, kernels -from sigpyproc.core.rfi import RFIMask + +if TYPE_CHECKING: + from collections.abc import Callable, Iterator + + from typing_extensions import Buffer, TypedDict, Unpack + + from sigpyproc.block import FilterbankBlock + from sigpyproc.header import Header + + class PlanKwargs(TypedDict, total=False): + gulp: int + start: int + nsamps: int | None + skipback: int + description: str | None + quiet: bool + allocator: Callable[[int], Buffer] | None class Filterbank(ABC): - """Base class exporting methods for the manipulation of frequency-major order pulsar data. + """Base class for manipulating frequency-major order pulsar data. Notes ----- @@ -26,17 +39,16 @@ class Filterbank(ABC): """ def __init__(self) -> None: - self._chan_stats: stats.ChannelStats | None = None + self._chan_stats: ChannelStats | None = None @property @abstractmethod def header(self) -> Header: """:class:`~sigpyproc.header.Header`: Header metadata of input file.""" - pass @abstractmethod def read_block(self, start: int, nsamps: int) -> FilterbankBlock: - """Read a block of filterbank data. + """Read a data block from the filterbank file stream. Parameters ---------- @@ -55,7 +67,6 @@ def read_block(self, start: int, nsamps: int) -> FilterbankBlock: ValueError if requested samples are out of range """ - pass @abstractmethod def read_dedisp_block(self, start: int, nsamps: int, dm: float) -> FilterbankBlock: @@ -82,11 +93,11 @@ def read_dedisp_block(self, start: int, nsamps: int, dm: float) -> FilterbankBlo ValueError if requested dedispersed samples are out of range """ - pass @abstractmethod def read_plan( self, + *, gulp: int = 16384, start: int = 0, nsamps: int | None = None, @@ -117,7 +128,7 @@ def read_plan( data to be read into, by default None Yields - ------- + ------ :py:obj:`~collections.abc.Iterator` (tuple(int, int, :py:obj:`~numpy.ndarray`)) Tuple of number of samples read, index of read, and the unpacked data read @@ -142,41 +153,44 @@ def read_plan( # do something where data always has contains ``nchans*nsamps`` points. 
""" - pass @property - def chan_stats(self) -> stats.ChannelStats | None: - """:class:`~sigpyproc.core.stats.ChannelStats`: Channel statistics of input data.""" + def chan_stats(self) -> ChannelStats | None: + """:class:`~sigpyproc.core.stats.ChannelStats`: Channel statistics.""" return self._chan_stats - def compute_stats(self, **plan_kwargs) -> None: + def compute_stats(self, **plan_kwargs: Unpack[PlanKwargs]) -> None: """Compute channelwise statistics of data (upto kurtosis). Parameters ---------- - **plan_kwargs : dict + **plan_kwargs : Unpack[PlanKwargs] Keyword arguments for :func:`read_plan`. """ - bag = stats.ChannelStats(self.header.nchans, self.header.nsamples) + bag = ChannelStats(self.header.nchans, self.header.nsamples) for nsamps, ii, data in self.read_plan(**plan_kwargs): bag.push_data(data, nsamps, ii, mode="full") self._chan_stats = bag - def compute_stats_basic(self, **plan_kwargs) -> None: + def compute_stats_basic(self, **plan_kwargs: Unpack[PlanKwargs]) -> None: """Compute channelwise statistics of data (only mean and rms). Parameters ---------- - **plan_kwargs : dict + **plan_kwargs : Unpack[PlanKwargs] Keyword arguments for :func:`read_plan`. """ - bag = stats.ChannelStats(self.header.nchans, self.header.nsamples) + bag = ChannelStats(self.header.nchans, self.header.nsamples) for nsamps, ii, data in self.read_plan(**plan_kwargs): bag.push_data(data, nsamps, ii, mode="basic") self._chan_stats = bag def collapse( - self, gulp: int = 16384, start: int = 0, nsamps: int | None = None, **plan_kwargs + self, + gulp: int = 16384, + start: int = 0, + nsamps: int | None = None, + **plan_kwargs: Unpack[PlanKwargs], ) -> TimeSeries: """Sum across all frequencies for each time sample. @@ -199,12 +213,15 @@ def collapse( tim_len = (self.header.nsamples - start) if nsamps is None else nsamps tim_ar = np.zeros(tim_len, dtype=np.float32) for nsamp, ii, data in self.read_plan( - gulp=gulp, start=start, nsamps=nsamps, **plan_kwargs + gulp=gulp, + start=start, + nsamps=nsamps, + **plan_kwargs, ): kernels.extract_tim(data, tim_ar, self.header.nchans, nsamp, ii * gulp) return TimeSeries(tim_ar, self.header.new_header({"nchans": 1, "dm": 0})) - def bandpass(self, **plan_kwargs) -> TimeSeries: + def bandpass(self, **plan_kwargs: Unpack[PlanKwargs]) -> TimeSeries: """Average across each time sample for all frequencies. Parameters @@ -225,7 +242,12 @@ def bandpass(self, **plan_kwargs) -> TimeSeries: bpass_ar /= num_samples return TimeSeries(bpass_ar, self.header.new_header({"nchans": 1})) - def dedisperse(self, dm: float, gulp: int = 16384, **plan_kwargs) -> TimeSeries: + def dedisperse( + self, + dm: float, + gulp: int = 16384, + **plan_kwargs: Unpack[PlanKwargs], + ) -> TimeSeries: """Dedisperse the data and collapse to a time series. 
Parameters @@ -253,7 +275,9 @@ def dedisperse(self, dm: float, gulp: int = 16384, **plan_kwargs) -> TimeSeries: tim_len = self.header.nsamples - max_delay tim_ar = np.zeros(tim_len, dtype=np.float32) for nsamps, ii, data in self.read_plan( - gulp=gulp, skipback=max_delay, **plan_kwargs + gulp=gulp, + skipback=max_delay, + **plan_kwargs, ): kernels.dedisperse( data, @@ -266,13 +290,18 @@ def dedisperse(self, dm: float, gulp: int = 16384, **plan_kwargs) -> TimeSeries: ) return TimeSeries(tim_ar, self.header.new_header({"nchans": 1, "dm": dm})) - def read_chan(self, chan: int, gulp: int = 16384, **plan_kwargs) -> TimeSeries: + def read_chan( + self, + ichan: int, + gulp: int = 16384, + **plan_kwargs: Unpack[PlanKwargs], + ) -> TimeSeries: """Read a single frequency channel from the data as a time series. Parameters ---------- - chan : int - channel to retrieve (0 is the highest frequency channel) + ichan : int + channel index to retrieve (0 is the highest frequency channel) gulp : int, optional number of samples in each read, by default 16384 **plan_kwargs : dict @@ -286,17 +315,22 @@ def read_chan(self, chan: int, gulp: int = 16384, **plan_kwargs) -> TimeSeries: Raises ------ ValueError - If chan is out of range (chan < 0 or chan > total channels). + If ichan is out of range (ichan < 0 or ichan > nchans). """ - if chan >= self.header.nchans or chan < 0: - raise ValueError("Selected channel out of range.") + if ichan >= self.header.nchans or ichan < 0: + msg = f"Selected channel {ichan} out of range." + raise ValueError(msg) tim_ar = np.empty(self.header.nsamples, dtype=np.float32) for nsamps, ii, data in self.read_plan(gulp=gulp, **plan_kwargs): - data = data.reshape(nsamps, self.header.nchans) - tim_ar[ii * gulp : (ii + 1) * gulp] = data[:, chan] + data_2d = data.reshape(nsamps, self.header.nchans) + tim_ar[ii * gulp : (ii + 1) * gulp] = data_2d[:, ichan] return TimeSeries(tim_ar, self.header.new_header({"dm": 0, "nchans": 1})) - def invert_freq(self, filename: str = None, **plan_kwargs) -> str: + def invert_freq( + self, + filename: str | None = None, + **plan_kwargs: Unpack[PlanKwargs], + ) -> str: """Invert the frequency ordering of the data and write to a new file. Parameters @@ -328,10 +362,10 @@ def invert_freq(self, filename: str = None, **plan_kwargs) -> str: def apply_channel_mask( self, - chanmask: npt.ArrayLike, - maskvalue: int | float = 0, + chanmask: np.ndarray, + maskvalue: float = 0, filename: str | None = None, - **plan_kwargs, + **plan_kwargs: Unpack[PlanKwargs], ) -> str: """Apply a channel mask to the data and write to a new file. @@ -339,7 +373,7 @@ def apply_channel_mask( ---------- chanmask : :py:obj:`~numpy.typing.ArrayLike` boolean array of channel mask (1 or True for bad channel) - maskvalue : int or float, optional + maskvalue : float, optional value to set the masked data to, by default 0 filename : str, optional name of the output filterbank file, by default ``basename_masked.fil`` @@ -368,7 +402,7 @@ def downsample( ffactor: int = 1, gulp: int = 16384, filename: str | None = None, - **plan_kwargs, + **plan_kwargs: Unpack[PlanKwargs], ) -> str: """Downsample data in time and/or frequency and write to file. 
@@ -398,7 +432,8 @@ def downsample( if filename is None: filename = f"{self.header.basename}_f{ffactor:d}_t{tfactor:d}.fil" if self.header.nchans % ffactor != 0: - raise ValueError("Bad frequency factor given") + msg = f"Bad frequency factor given: {ffactor:d}" + raise ValueError(msg) # Gulp must be a multiple of tfactor gulp = int(np.ceil(gulp / tfactor) * tfactor) @@ -412,13 +447,21 @@ def downsample( for nsamps, _ii, data in self.read_plan(gulp=gulp, **plan_kwargs): write_ar = kernels.downsample_2d( - data, tfactor, ffactor, self.header.nchans, nsamps + data, + tfactor, + ffactor, + self.header.nchans, + nsamps, ) out_file.cwrite(write_ar) return out_file.name def extract_samps( - self, start: int, nsamps: int, filename: str | None = None, **plan_kwargs + self, + start: int, + nsamps: int, + filename: str | None = None, + **plan_kwargs: Unpack[PlanKwargs], ) -> str: """Extract a subset of time samples from the data and write to file. @@ -444,7 +487,8 @@ def extract_samps( If `start` or `nsamps` are out of bounds. """ if start < 0 or start + nsamps > self.header.nsamples: - raise ValueError("Selected samples out of range") + msg = f"Selected samples out of range: {start:d} to {start+nsamps:d}" + raise ValueError(msg) if filename is None: filename = f"{self.header.basename}_samps_{start:d}_{start+nsamps:d}.fil" out_file = self.header.prep_outfile( @@ -452,15 +496,19 @@ def extract_samps( updates={"tstart": self.header.mjd_after_nsamps(start)}, nbits=self.header.nbits, ) - for _count, _ii, data in self.read_plan( - start=start, nsamps=nsamps, **plan_kwargs + for _, _, data in self.read_plan( + start=start, + nsamps=nsamps, + **plan_kwargs, ): out_file.cwrite(data) out_file.close() return out_file.name def extract_chans( - self, chans: npt.ArrayLike | None = None, **plan_kwargs + self, + chans: np.ndarray | None = None, + **plan_kwargs: Unpack[PlanKwargs], ) -> list[str]: """Extract a subset of channels from the data and write each to file. @@ -489,7 +537,8 @@ def extract_chans( chans = np.arange(self.header.nchans) chans = np.array(chans).astype("int") if np.all(np.logical_or(chans >= self.header.nchans, chans < 0)): - raise ValueError("Selected channel out of range.") + msg = f"Selected channels out of range: {chans.min()} to {chans.max()}" + raise ValueError(msg) out_files = [ self.header.prep_outfile( @@ -500,9 +549,9 @@ def extract_chans( for chan in chans ] for nsamps, _ii, data in self.read_plan(**plan_kwargs): - data = data.reshape(nsamps, self.header.nchans) + data_2d = data.reshape(nsamps, self.header.nchans) for ifile, out_file in enumerate(out_files): - out_file.cwrite(data[:, chans[ifile]]) + out_file.cwrite(data_2d[:, chans[ifile]]) for out_file in out_files: out_file.close() @@ -510,7 +559,11 @@ def extract_chans( return [out_file.name for out_file in out_files] def extract_bands( - self, chanstart: int, nchans: int, chanpersub: int | None = None, **plan_kwargs + self, + chanstart: int, + nchans: int, + chanpersub: int | None = None, + **plan_kwargs: Unpack[PlanKwargs], ) -> list[str]: """Extract a subset of Sub-bands from the data and write each to file. @@ -537,7 +590,8 @@ def extract_bands( ValueError If ``nchans`` is not divisible by ``chanpersub``. ValueError - If ``chanstart`` is out of range (``chanstart`` < 0 or ``chanstart`` > total channels). + If ``chanstart`` is out of range (``chanstart`` < 0 + or ``chanstart`` > total channels). 
Notes ----- @@ -546,11 +600,14 @@ def extract_bands( if chanpersub is None: chanpersub = nchans if chanpersub <= 1 or chanpersub > nchans: - raise ValueError("chanpersub must be > 1 and <= nchans") + msg = f"chanpersub must be > 1 and <= nchans. Got {chanpersub}" + raise ValueError(msg) if chanstart + nchans > self.header.nchans or chanstart < 0: - raise ValueError("Selected channel out of range.") + msg = f"Selected channels out of range: {chanstart} to {chanstart+nchans}" + raise ValueError(msg) if nchans % chanpersub != 0: - raise ValueError("Number of channels must be divisible by sub-band size.") + msg = f"Number of channels must be divisible by sub-band size. Got {nchans}" + raise ValueError(msg) nsub = (self.header.nchans - chanstart) // chanpersub fstart = self.header.fch1 + chanstart * self.header.foff @@ -568,10 +625,10 @@ def extract_bands( ] for nsamps, _ii, data in self.read_plan(**plan_kwargs): - data = data.reshape(nsamps, self.header.nchans) + data_2d = data.reshape(nsamps, self.header.nchans) for ifile, out_file in enumerate(out_files): iband_chanstart = chanstart + ifile * chanpersub - subband_ar = data[:, iband_chanstart : iband_chanstart + chanpersub] + subband_ar = data_2d[:, iband_chanstart : iband_chanstart + chanpersub] out_file.cwrite(subband_ar.ravel()) for out_file in out_files: @@ -582,20 +639,21 @@ def extract_bands( def requantize( self, nbits_out: int, - remove_bandpass: bool = False, filename: str | None = None, - **plan_kwargs, + *, + remove_bandpass: bool = False, # noqa: ARG002 + **plan_kwargs: Unpack[PlanKwargs], ) -> str: - """Eequantize the data and write to a new file. + """Requantize the data and write to a new file. Parameters ---------- nbits_out : int number of bits into requantize the data - remove_bandpass : bool, optional - remove the bandpass from the data, by default False filename : str, optional name of output file, by default ``basename_digi.fil`` + remove_bandpass : bool, optional + remove the bandpass from the data, by default False **plan_kwargs : dict Keyword arguments for :func:`read_plan`. @@ -610,7 +668,8 @@ def requantize( If ``nbits_out`` is less than 1 or greater than 32. """ if nbits_out not in {1, 2, 4, 8, 16, 32}: - raise ValueError("nbits_out must be one of {1, 2, 4, 8, 16, 32}") + msg = f"nbits_out must be one of {1, 2, 4, 8, 16, 32}, got {nbits_out}" + raise ValueError(msg) if filename is None: filename = f"{self.header.basename}_digi.fil" @@ -620,7 +679,11 @@ def requantize( out_file.close() return out_file.name - def remove_zerodm(self, filename: str | None = None, **plan_kwargs): + def remove_zerodm( + self, + filename: str | None = None, + **plan_kwargs: Unpack[PlanKwargs], + ) -> str: """Remove the channel-weighted zero-DM from the data and write to disk. Parameters @@ -641,7 +704,8 @@ def remove_zerodm(self, filename: str | None = None, **plan_kwargs): References ---------- - .. [1] R. P. Eatough, E. F. Keane, A. G. Lyne, An interference removal technique for radio pulsar searches, + .. [1] R. P. Eatough, E. F. Keane, A. G. Lyne, An interference removal + technique for radio pulsar searches, MNRAS, Volume 395, Issue 1, May 2009, Pages 410-415. 
""" if filename is None: @@ -650,19 +714,30 @@ def remove_zerodm(self, filename: str | None = None, **plan_kwargs): bpass = self.bandpass(**plan_kwargs) chanwts = bpass / bpass.sum() out_ar = np.empty( - self.header.nsamples * self.header.nchans, dtype=self.header.dtype + self.header.nsamples * self.header.nchans, + dtype=self.header.dtype, ) out_file = self.header.prep_outfile(filename, nbits=self.header.nbits) for nsamps, _ii, data in self.read_plan(**plan_kwargs): kernels.remove_zerodm( - data, out_ar, bpass, chanwts, self.header.nchans, nsamps + data, + out_ar, + bpass, + chanwts, + self.header.nchans, + nsamps, ) out_file.cwrite(out_ar[: nsamps * self.header.nchans]) out_file.close() return out_file.name def subband( - self, dm: float, nsub: int, filename: str = None, gulp: int = 16384, **plan_kwargs + self, + dm: float, + nsub: int, + filename: str | None = None, + gulp: int = 16384, + **plan_kwargs: Unpack[PlanKwargs], ) -> str: """Produce a set of dedispersed subbands from the data. @@ -705,7 +780,9 @@ def subband( out_file = self.header.prep_outfile(filename, changes, nbits=32) for nsamps, _ii, data in self.read_plan( - gulp=gulp, skipback=max_delay, **plan_kwargs + gulp=gulp, + skipback=max_delay, + **plan_kwargs, ): kernels.subband( data, @@ -729,7 +806,7 @@ def fold( nints: int = 32, nbands: int = 32, gulp: int = 16384, - **plan_kwargs, + **plan_kwargs: Unpack[PlanKwargs], ) -> FoldedData: """Fold data into discrete phase, subintegration and subband bins. @@ -772,14 +849,19 @@ def fold( if nbins > period / self.header.tsamp: warnings.warn("Number of phase bins is greater than period/sampling time") if (self.header.nsamples * self.header.nchans) // (nbands * nints * nbins) < 10: - raise ValueError("nbands x nints x nbins is too large.") + msg = f"nbands x nints x nbins is too large: {nbands*nints*nbins}" + raise ValueError(msg) nbands = min(nbands, self.header.nchans) chan_delays = self.header.get_dmdelays(dm) max_delay = int(chan_delays.max()) gulp = max(2 * max_delay, gulp) fold_ar = np.zeros(nbins * nints * nbands, dtype="float32") count_ar = np.zeros(nbins * nints * nbands, dtype="int32") - for nsamps, ii, data in self.read_plan(gulp, skipback=max_delay, **plan_kwargs): + for nsamps, ii, data in self.read_plan( + gulp=gulp, + skipback=max_delay, + **plan_kwargs, + ): kernels.fold( data, fold_ar, @@ -805,10 +887,10 @@ def clean_rfi( self, method: str = "mad", threshold: float = 3, - chanmask: npt.ArrayLike | None = None, - custom_funcn: Callable[[npt.ArrayLike], np.ndarray] | None = None, + chanmask: np.ndarray | None = None, + custom_funcn: Callable[[np.ndarray], np.ndarray] | None = None, filename: str | None = None, - **plan_kwargs, + **plan_kwargs: Unpack[PlanKwargs], ) -> tuple[str, RFIMask]: """Clean RFI from the data. 
diff --git a/sigpyproc/block.py b/sigpyproc/block.py
index 97af924..7f87bab 100644
--- a/sigpyproc/block.py
+++ b/sigpyproc/block.py
@@ -1,12 +1,18 @@
 from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
 import numpy as np
-from numpy import typing as npt
 
-from sigpyproc.header import Header
-from sigpyproc.timeseries import TimeSeries
 from sigpyproc.core import kernels
+from sigpyproc.timeseries import TimeSeries
 from sigpyproc.utils import roll_array
 
+if TYPE_CHECKING:
+    from typing_extensions import Self
+
+    from sigpyproc.header import Header
+
 
 class FilterbankBlock(np.ndarray):
     """An array class to handle a discrete block of data in time-major order.
@@ -31,15 +37,18 @@ class FilterbankBlock(np.ndarray):
     """
 
     def __new__(
-        cls, input_array: npt.ArrayLike, header: Header, dm: float = 0
-    ) -> FilterbankBlock:
+        cls,
+        input_array: np.ndarray,
+        header: Header,
+        dm: float = 0,
+    ) -> Self:
         """Create a new block array."""
         obj = np.asarray(input_array).astype(np.float32, copy=False).view(cls)
         obj.header = header
         obj.dm = dm
         return obj
 
-    def __array_finalize__(self, obj):
+    def __array_finalize__(self, obj: Self | None) -> None:
         if obj is None:
             return
         self.header = getattr(obj, "header", None)
@@ -68,13 +77,22 @@ def downsample(self, tfactor: int = 1, ffactor: int = 1) -> FilterbankBlock:
             If number of time samples is not divisible by `tfactor`.
         """
         if self.shape[0] % ffactor != 0:
-            raise ValueError("Bad frequency factor given")
+            msg = f"Bad frequency factor given: {ffactor}"
+            raise ValueError(msg)
         if self.shape[1] % tfactor != 0:
-            raise ValueError("Bad time factor given")
+            msg = f"Bad time factor given: {tfactor}"
+            raise ValueError(msg)
         ar = self.transpose().ravel().copy()
-        new_ar = kernels.downsample_2d(ar, tfactor, ffactor, self.shape[0], self.shape[1])
+        new_ar = kernels.downsample_2d(
+            ar,
+            tfactor,
+            ffactor,
+            self.shape[0],
+            self.shape[1],
+        )
         new_ar = new_ar.reshape(
-            self.shape[1] // tfactor, self.shape[0] // ffactor
+            self.shape[1] // tfactor,
+            self.shape[0] // ffactor,
         ).transpose()
         changes = {
             "tsamp": self.header.tsamp * tfactor,
@@ -84,7 +102,7 @@ def downsample(self, tfactor: int = 1, ffactor: int = 1) -> FilterbankBlock:
         }
         return FilterbankBlock(new_ar, self.header.new_header(changes))
 
-    def normalise(self, by: str = "mean", chans: bool = True) -> FilterbankBlock:
+    def normalise(self, *, by: str = "mean", chans: bool = True) -> FilterbankBlock:
         """Normalise the data block (Subtract mean/median, divide by std).
 
         Parameters
@@ -98,7 +116,15 @@ def normalise(
         -------
         FilterbankBlock
             A normalised version of the data block
+
+        Raises
+        ------
+        ValueError
+            If `by` is not "mean" or "median".
         """
+        if by not in {"mean", "median"}:
+            msg = f"Unknown normalisation method: {by}"
+            raise ValueError(msg)
         np_op = getattr(np, by)
         if chans:
             norm_block = self - np_op(self, axis=1)[:, np.newaxis]
@@ -121,7 +147,7 @@ def get_tim(self) -> TimeSeries:
         ts = self.sum(axis=0)
         return TimeSeries(ts, self.header.dedispersed_header(dm=self.dm))
 
-    def get_bandpass(self) -> npt.ArrayLike:
+    def get_bandpass(self) -> np.ndarray:
         """Average across each time sample for all frequencies.
 
         Returns
@@ -132,7 +158,11 @@ def get_bandpass(
         return self.sum(axis=1)
 
     def dedisperse(
-        self, dm: float, only_valid_samples: bool = False, ref_freq: str = "ch1"
+        self,
+        dm: float,
+        *,
+        only_valid_samples: bool = False,
+        ref_freq: str = "ch1",
     ) -> FilterbankBlock:
         """Dedisperse the block.
 
@@ -165,10 +195,11 @@ def dedisperse(
         if only_valid_samples:
             valid_samps = self.shape[1] - delays[-1]
             if valid_samps < 0:
-                raise ValueError(
-                    f"Insufficient time samples to dedisperse to {dm} (requires "
-                    + f"at least {delays[-1]} samples, given {self.shape[1]})."
+                msg = (
+                    f"Insufficient time samples to dedisperse to {dm} (requires at "
+                    f"least {delays[-1]} samples, given {self.shape[1]})."
                 )
+                raise ValueError(msg)
             new_ar = np.empty((self.shape[0], valid_samps), dtype=self.dtype)
             for ichan in range(self.shape[0]):
                 new_ar[ichan] = self[ichan, delays[ichan] : delays[ichan] + valid_samps]
@@ -179,9 +210,12 @@ def dedisperse(
         return FilterbankBlock(new_ar, self.header.new_header(), dm=dm)
 
     def dmt_transform(
-        self, dm: float, dmsteps: int = 512, ref_freq: str = "ch1"
+        self,
+        dm: float,
+        dmsteps: int = 512,
+        ref_freq: str = "ch1",
     ) -> FilterbankBlock:
-        """Generate a DM-time transform of the data block by dedispersing at adjacent DM values.
+        """Generate a DM-time transform by dedispersing data block at adjacent DMs.
 
         Parameters
         ----------
@@ -196,14 +230,14 @@ def dmt_transform(
         -------
         FilterbankBlock
             2 dimensional array of DM-time transform
-        """
+        """
         dm_arr = dm + np.linspace(-dm, dm, dmsteps)
         new_ar = np.empty((dmsteps, self.shape[1]), dtype=self.dtype)
         for idm, dm_val in enumerate(dm_arr):
             new_ar[idm] = self.dedisperse(dm_val, ref_freq=ref_freq).get_tim()
         return FilterbankBlock(new_ar, self.header.new_header({"nchans": 1}), dm=dm)
 
-    def to_file(self, filename: str = None) -> str:
+    def to_file(self, filename: str | None = None) -> str:
         """Write the data to file.
 
         Parameters
@@ -218,9 +252,7 @@ def to_file(
         """
         if filename is None:
             mjd_after = self.header.mjd_after_nsamps(self.shape[1])
-            filename = (
-                f"{self.header.basename}_{self.header.tstart:.12f}_to_{mjd_after:.12f}.fil"
-            )
+            filename = f"{self.header.basename}_{self.header.tstart:.12f}_to_{mjd_after:.12f}.fil"
         changes = {"nbits": 32}
         out_file = self.header.prep_outfile(filename, changes, nbits=32)
         out_file.cwrite(self.transpose().ravel())
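Review note on `FilterbankBlock.dedisperse`: with `only_valid_samples=True` each channel is shifted by its integer DM delay and only the overlapping span is kept. A toy numpy illustration (made-up delays, not the library kernel):

```python
# Toy illustration of delay-and-trim dedispersion over a (nchans, nsamps) block.
import numpy as np

nchans, nsamps = 4, 16
block = np.arange(nchans * nsamps, dtype=np.float32).reshape(nchans, nsamps)
delays = np.array([0, 1, 2, 3])  # per-channel DM delays in samples, delays[0] = 0

valid = nsamps - delays[-1]  # samples present in every shifted channel
dedispersed = np.empty((nchans, valid), dtype=block.dtype)
for ichan in range(nchans):
    dedispersed[ichan] = block[ichan, delays[ichan] : delays[ichan] + valid]
print(dedispersed.shape)  # (4, 13)
```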
diff --git a/sigpyproc/core/kernels.py b/sigpyproc/core/kernels.py
index 90b79a0..5f709f1 100644
--- a/sigpyproc/core/kernels.py
+++ b/sigpyproc/core/kernels.py
@@ -1,75 +1,181 @@
+from collections.abc import Callable
+
 import numpy as np
 from numba import njit, prange, types
-from numba.extending import overload
 from numba.experimental import jitclass
-from scipy import constants
+from numba.extending import overload
+
+CONST_C_VAL = 299792458.0  # Speed of light in m/s (astropy.constants.c.value)
+
+
+def packunpack_njit(func: types.FunctionType) -> types.FunctionType:
+    return njit(
+        "void(u1[::1], u1[::1])",
+        cache=True,
+        parallel=True,
+        fastmath=True,
+        locals={"pos": types.i8},
+    )(func)
 
+
+def packunpack_njit_serial(func: types.FunctionType) -> types.FunctionType:
+    return njit(
+        "void(u1[::1], u1[::1])",
+        cache=True,
+        parallel=False,
+        fastmath=True,
+        locals={"pos": types.i8},
+    )(func)
+
+
+@packunpack_njit
+def unpack1_8_big(array: np.ndarray, unpacked: np.ndarray) -> None:
-@njit("u1[:](u1[:], u1[:])", cache=True, parallel=True)
-def unpack1_8(array, unpacked):
-    bitfact = 8
     for ii in prange(array.size):
-        unpacked[ii * bitfact + 0] = (array[ii] >> 7) & 1
-        unpacked[ii * bitfact + 1] = (array[ii] >> 6) & 1
-        unpacked[ii * bitfact + 2] = (array[ii] >> 5) & 1
-        unpacked[ii * bitfact + 3] = (array[ii] >> 4) & 1
-        unpacked[ii * bitfact + 4] = (array[ii] >> 3) & 1
-        unpacked[ii * bitfact + 5] = (array[ii] >> 2) & 1
-        unpacked[ii * bitfact + 6] = (array[ii] >> 1) & 1
-        unpacked[ii * bitfact + 7] = (array[ii] >> 0) & 1
-    return unpacked
-
-
-@njit("u1[:](u1[:], u1[:])", cache=True, parallel=True)
-def unpack2_8(array, unpacked):
-    bitfact = 8 // 2
+        pos = ii * 8
+        for jj in range(8):
+            unpacked[pos + (7 - jj)] = (array[ii] >> jj) & 1
+
+
+@packunpack_njit
+def unpack1_8_little(array: np.ndarray, unpacked: np.ndarray) -> None:
     for ii in prange(array.size):
-        unpacked[ii * bitfact + 0] = (array[ii] & 0xC0) >> 6
-        unpacked[ii * bitfact + 1] = (array[ii] & 0x30) >> 4
-        unpacked[ii * bitfact + 2] = (array[ii] & 0x0C) >> 2
-        unpacked[ii * bitfact + 3] = (array[ii] & 0x03) >> 0
-    return unpacked
+        pos = ii * 8
+        for jj in range(8):
+            unpacked[pos + jj] = (array[ii] >> jj) & 1
 
-@njit("u1[:](u1[:], u1[:])", cache=True, parallel=True)
-def unpack4_8(array, unpacked):
-    bitfact = 8 // 4
+
+@packunpack_njit
+def unpack2_8_big(array: np.ndarray, unpacked: np.ndarray) -> None:
     for ii in prange(array.size):
-        unpacked[ii * bitfact + 0] = (array[ii] & 0xF0) >> 4
-        unpacked[ii * bitfact + 1] = (array[ii] & 0x0F) >> 0
+        pos = ii * 4
+        unpacked[pos + 3] = (array[ii] & 0x03) >> 0
+        unpacked[pos + 2] = (array[ii] & 0x0C) >> 2
+        unpacked[pos + 1] = (array[ii] & 0x30) >> 4
+        unpacked[pos + 0] = (array[ii] & 0xC0) >> 6
 
-    return unpacked
 
+@packunpack_njit
+def unpack2_8_little(array: np.ndarray, unpacked: np.ndarray) -> None:
+    for ii in prange(array.size):
+        pos = ii * 4
+        unpacked[pos + 0] = (array[ii] & 0x03) >> 0
+        unpacked[pos + 1] = (array[ii] & 0x0C) >> 2
+        unpacked[pos + 2] = (array[ii] & 0x30) >> 4
+        unpacked[pos + 3] = (array[ii] & 0xC0) >> 6
 
-@njit("u1[:](u1[:])", cache=True, parallel=True)
-def pack2_8(array):
-    bitfact = 8 // 2
-    packed = np.zeros(shape=array.size // bitfact, dtype=np.uint8)
-    for ii in prange(array.size // bitfact):
+
+@packunpack_njit
+def unpack4_8_big(array: np.ndarray, unpacked: np.ndarray) -> None:
+    for ii in prange(array.size):
+        pos = ii * 2
+        unpacked[pos + 1] = (array[ii] & 0x0F) >> 0
+        unpacked[pos + 0] = (array[ii] & 0xF0) >> 4
+
+
+@packunpack_njit
+def unpack4_8_little(array: np.ndarray, unpacked: np.ndarray) -> None:
+    for ii in prange(array.size):
+        pos = ii * 2
+        unpacked[pos + 0] = (array[ii] & 0x0F) >> 0
+        unpacked[pos + 1] = (array[ii] & 0xF0) >> 4
+
+
+@packunpack_njit
+def pack1_8_big(array: np.ndarray, packed: np.ndarray) -> None:
+    for ii in prange(packed.size):
+        pos = ii * 8
         packed[ii] = (
-            (array[ii * 4] << 6)
-            | (array[ii * 4 + 1] << 4)
-            | (array[ii * 4 + 2] << 2)
-            | array[ii * 4 + 3]
+            (array[pos + 0] << 7)
+            | (array[pos + 1] << 6)
+            | (array[pos + 2] << 5)
+            | (array[pos + 3] << 4)
+            | (array[pos + 4] << 3)
+            | (array[pos + 5] << 2)
+            | (array[pos + 6] << 1)
+            | array[pos + 7]
         )
 
-    return packed
 
+@packunpack_njit
+def pack1_8_little(array: np.ndarray, packed: np.ndarray) -> None:
+    for ii in prange(packed.size):
+        pos = ii * 8
+        packed[ii] = (
+            (array[pos + 7] << 7)
+            | (array[pos + 6] << 6)
+            | (array[pos + 5] << 5)
+            | (array[pos + 4] << 4)
+            | (array[pos + 3] << 3)
+            | (array[pos + 2] << 2)
+            | (array[pos + 1] << 1)
+            | array[pos + 0]
+        )
 
-@njit("u1[:](u1[:])", cache=True, parallel=True)
-def pack4_8(array):
-    bitfact = 8 // 4
-    packed = np.zeros(shape=array.size // bitfact, dtype=np.uint8)
-    for ii in prange(array.size // bitfact):
-        packed[ii] = (array[ii * 2] << 4) | array[ii * 2 + 1]
 
-    return packed
+@packunpack_njit
+def pack2_8_big(array: np.ndarray, packed: np.ndarray) -> None:
+    for ii in prange(packed.size):
+        pos = ii * 4
+        packed[ii] = (
+            (array[pos + 0] << 6)
+            | (array[pos + 1] << 4)
+            | (array[pos + 2] << 2)
+            | array[pos + 3]
+        )
+
+
+@packunpack_njit
+def pack2_8_little(array: np.ndarray, packed: np.ndarray) -> None:
+    for ii in prange(packed.size):
+        pos = ii * 4
+        packed[ii] = (
+            (array[pos + 3] << 6)
+            | (array[pos + 2] << 4)
+            | (array[pos + 1] << 2)
+            | array[pos + 0]
+        )
+
+
+@packunpack_njit
+def pack4_8_big(array: np.ndarray, packed: np.ndarray) -> None:
+    for ii in prange(packed.size):
+        pos = ii * 2
+        packed[ii] = (array[pos + 0] << 4) | array[pos + 1]
+
+
+@packunpack_njit
+def pack4_8_little(array: np.ndarray, packed: np.ndarray) -> None:
+    for ii in prange(packed.size):
+        pos = ii * 2
+        packed[ii] = (array[pos + 1] << 4) | array[pos + 0]
+
+
+unpack1_8_big_serial = packunpack_njit_serial(unpack1_8_big.py_func)
+unpack1_8_little_serial = packunpack_njit_serial(unpack1_8_little.py_func)
+unpack2_8_big_serial = packunpack_njit_serial(unpack2_8_big.py_func)
+unpack2_8_little_serial = packunpack_njit_serial(unpack2_8_little.py_func)
+unpack4_8_big_serial = packunpack_njit_serial(unpack4_8_big.py_func)
+unpack4_8_little_serial = packunpack_njit_serial(unpack4_8_little.py_func)
+pack1_8_big_serial = packunpack_njit_serial(pack1_8_big.py_func)
+pack1_8_little_serial = packunpack_njit_serial(pack1_8_little.py_func)
+pack2_8_big_serial = packunpack_njit_serial(pack2_8_big.py_func)
+pack2_8_little_serial = packunpack_njit_serial(pack2_8_little.py_func)
+pack4_8_big_serial = packunpack_njit_serial(pack4_8_big.py_func)
+pack4_8_little_serial = packunpack_njit_serial(pack4_8_little.py_func)
 
 @njit(cache=True)
-def np_apply_along_axis(func1d, axis, arr):
-    assert arr.ndim == 2
-    assert axis in {0, 1}
+def np_apply_along_axis(
+    func1d: Callable[[np.ndarray], np.ndarray],
+    axis: int,
+    arr: np.ndarray,
+) -> np.ndarray:
+    if arr.ndim != 2:
+        msg = f"np_apply_along_axis only works on 2D arrays, got {arr.ndim}"
+        raise ValueError(msg)
+    if axis not in {0, 1}:
+        msg = f"axis should be 0 or 1, got {axis}"
+        raise ValueError(msg)
     if axis == 0:
         result = np.empty(arr.shape[1], dtype=arr.dtype)
         for ii in range(arr.shape[1]):
@@ -82,7 +188,7 @@ def np_apply_along_axis(func1d, axis, arr):
 
 
 @njit(cache=True)
-def np_mean(array, axis):
+def np_mean(array: np.ndarray, axis: int) -> np.ndarray:
     return np_apply_along_axis(np.mean, axis, array)
 
 
@@ -100,7 +206,7 @@ def ol_downcast(intype, result):
 
 
 @njit(cache=True)
-def downsample_1d(array, factor):
+def downsample_1d(array: np.ndarray, factor: int) -> np.ndarray:
     reshaped_ar = np.reshape(array, (array.size // factor, factor))
     return np_mean(reshaped_ar, 1)
 
@@ -111,7 +217,13 @@ def downsample_1d(array, factor):
     parallel=True,
     locals={"temp": types.f8},
 )
-def downsample_2d(array, tfactor, ffactor, nchans, nsamps):
+def downsample_2d(
+    array: np.ndarray,
+    tfactor: int,
+    ffactor: int,
+    nchans: int,
+    nsamps: int,
+) -> np.ndarray:
     nsamps_new = nsamps // tfactor
     nchans_new = nchans // ffactor
     totfactor = ffactor * tfactor
@@ -133,7 +245,13 @@ def downsample_2d(array, tfactor, ffactor, nchans, nsamps):
     cache=True,
     parallel=True,
 )
-def extract_tim(inarray, outarray, nchans, nsamps, index):
+def extract_tim(
+    inarray: np.ndarray,
+    outarray: np.ndarray,
+    nchans: int,
+    nsamps: int,
+    index: int,
+) -> None:
     for isamp in prange(nsamps):
         for ichan in range(nchans):
             outarray[index + isamp] += inarray[nchans * isamp + ichan]
@@ -236,7 +354,9 @@ def fold(
     tobs = total_nsamps * tsamp
     for isamp in range(nsamps - maxdelay):
         tj = (isamp + index) * tsamp
-        phase = nbins * tj * (1 + accel * (tj - tobs) / (2 * constants.c)) / period + 0.5
+        phase = (
+            nbins * tj * (1 + accel * (tj - tobs) / (2 * CONST_C_VAL)) / period + 0.5
+        )
         phasebin = abs(int(phase)) % nbins
         subint = (isamp + index) // factor1
         pos1 = (subint * nbins * nsubs) + phasebin
@@ -254,7 +374,7 @@ def resample_tim(array, accel, tsamp):
     nsamps = len(array) - 1 if accel > 0 else len(array)
     resampled = np.zeros(nsamps, dtype=array.dtype)
 
-    partial_calc = (accel * tsamp) / (2 * constants.c)
+    partial_calc = (accel * tsamp) / (2 * CONST_C_VAL)
     tot_drift = partial_calc * (nsamps // 2) ** 2
     last_bin = 0
     for ii in range(nsamps):
@@ -297,7 +417,7 @@ def form_spec(fft_ar, interpolated=False):
         for ispec in range(specsize):
             rr = fft_ar[2 * ispec]
             ii = fft_ar[2 * ispec + 1]
-            aa = rr ** 2 + ii ** 2
+            aa = rr**2 + ii**2
             bb = ((rr - rl) ** 2 + (ii - il) ** 2) / 2
             spec_arr[ispec] = np.sqrt(max(aa, bb))
 
@@ -305,7 +425,9 @@ def form_spec(fft_ar, interpolated=False):
             il = ii
     else:
         for ispec in range(specsize):
-            spec_arr[ispec] = np.sqrt(fft_ar[2 * ispec] ** 2 + fft_ar[2 * ispec + 1] ** 2)
+            spec_arr[ispec] = np.sqrt(
+                fft_ar[2 * ispec] ** 2 + fft_ar[2 * ispec + 1] ** 2
+            )
     return spec_arr
 
 
@@ -367,9 +489,9 @@ def remove_rednoise(fftbuffer, startwidth, endwidth, endfreq, tsamp):
 
         oldinbuf[: 2 * numread_new] = newinbuf[: 2 * numread_new]
 
-    outbuffer[windex : windex + 2 * numread_old] = oldinbuf[: 2 * numread_old] / np.sqrt(
-        mean_old
-    )
+    outbuffer[windex : windex + 2 * numread_old] = oldinbuf[
+        : 2 * numread_old
+    ] / np.sqrt(mean_old)
 
     return outbuffer
 
@@ -385,21 +507,20 @@ def sum_harms(spec_arr, sum_arr, harm_arr, fact_arr, nharms, nsamps, nfold):
             fact_arr[kk] += 2 * kk + 1
 
 
-MomentsBagSpec = [
-    ("nchans", types.i4),
-    ("m1", types.f4[:]),
-    ("m2", types.f4[:]),
-    ("m3", types.f4[:]),
-    ("m4", types.f4[:]),
-    ("min", types.f4[:]),
-    ("max", types.f4[:]),
-    ("count", types.i4[:]),
-]
-
-
-@jitclass(MomentsBagSpec)
-class MomentsBag(object):
-    def __init__(self, nchans):
+@jitclass(
+    [
+        ("nchans", types.i4),
+        ("m1", types.f4[:]),
+        ("m2", types.f4[:]),
+        ("m3", types.f4[:]),
+        ("m4", types.f4[:]),
+        ("min", types.f4[:]),
+        ("max", types.f4[:]),
+        ("count", types.i4[:]),
+    ],
+)
+class MomentsBag:
+    def __init__(self, nchans: int) -> None:
         self.nchans = nchans
         self.m1 = np.zeros(nchans, dtype=np.float32)
         self.m2 = np.zeros(nchans, dtype=np.float32)
@@ -411,7 +532,12 @@ def __init__(self, nchans):
 
 
 @njit(cache=True, parallel=True, locals={"val": types.f8})
-def compute_online_moments_basic(array, bag, nsamps, startflag):
+def compute_online_moments_basic(
+    array: np.ndarray,
+    bag: MomentsBag,
+    nsamps: int,
+    startflag: int,
+) -> None:
     if startflag == 0:
         for ii in range(bag.nchans):
             bag.max[ii] = array[ii]
@@ -433,8 +559,13 @@ def compute_online_moments_basic(array, bag, nsamps, startflag):
 
 
 @njit(cache=True, parallel=True, locals={"val": types.f8})
-def compute_online_moments(array, bag, nsamps, startflag):
-    """Computing central moments in one pass through the data."""
+def compute_online_moments(
+    array: np.ndarray,
+    bag: MomentsBag,
+    nsamps: int,
+    startflag: int,
+) -> None:
+    """Compute central moments in one pass through the data."""
     if startflag == 0:
         for ii in range(bag.nchans):
             bag.max[ii] = array[ii]
@@ -464,14 +595,16 @@ def compute_online_moments(array, bag, nsamps, startflag):
 
 
 @njit(cache=True, parallel=False, locals={"val": types.f8})
-def add_online_moments(bag_a, bag_b, bag_c):
+def add_online_moments(bag_a: MomentsBag, bag_b: MomentsBag, bag_c: MomentsBag) -> None:
     bag_c.count = bag_a.count + bag_b.count
     delta = bag_b.m1 - bag_a.m1
     delta2 = delta * delta
     delta3 = delta * delta2
     delta4 = delta2 * delta2
 
-    bag_c.m1 = (bag_a.count * bag_a.m1 + bag_b.count * bag_b.m1) / bag_c.count
+    bag_c.m1 = (
+        bag_a.count * bag_a.m1 / bag_c.count + bag_b.count * bag_b.m1 / bag_c.count
+    )
     bag_c.m2 = bag_a.m2 + bag_b.m2 + delta2 * bag_a.count * bag_b.count / bag_c.count
 
     bag_c.m3 = (
@@ -481,7 +614,7 @@ def add_online_moments(bag_a, bag_b, bag_c):
         * bag_a.count
         * bag_b.count
         * (bag_a.count - bag_b.count)
-        / (bag_c.count ** 2)
+        / (bag_c.count**2)
     )
     bag_c.m3 += (
         3 * delta * (bag_a.count * bag_b.m2 - bag_b.count * bag_a.m2) / bag_c.count
@@ -493,14 +626,14 @@ def add_online_moments(bag_a, bag_b, bag_c):
         + delta4
         * bag_a.count
         * bag_b.count
-        * (bag_a.count ** 2 - bag_a.count * bag_b.count + bag_b.count ** 2)
-        / (bag_c.count ** 3)
+        * (bag_a.count**2 - bag_a.count * bag_b.count + bag_b.count**2)
+        / (bag_c.count**3)
     )
     bag_c.m4 += (
         6
         * delta2
         * (bag_a.count * bag_a.count * bag_b.m2 + bag_b.count * bag_b.count * bag_a.m2)
-        / (bag_c.count ** 2)
+        / (bag_c.count**2)
     )
     bag_c.m4 += (
         4 * delta * (bag_a.count * bag_b.m3 - bag_b.count * bag_a.m3) / bag_c.count
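Review note: the rewritten kernels add explicit big- and little-endian pack/unpack variants (plus `_serial` aliases compiled from the same `py_func`). For `nbits=1` their behaviour should match numpy's own bit packing, which makes for an easy sanity check (`bitorder` is supported since numpy 1.17):

```python
# Sanity-check sketch (not in the diff): nbits=1 big/little endian round trips.
import numpy as np

packed = np.array([0b10110001, 0b01000011], dtype=np.uint8)
big = np.unpackbits(packed)                        # MSB first, like unpack1_8_big
little = np.unpackbits(packed, bitorder="little")  # LSB first, like unpack1_8_little

assert np.array_equal(np.packbits(big), packed)
assert np.array_equal(np.packbits(little, bitorder="little"), packed)
print(big[:8], little[:8])
```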
diff --git a/sigpyproc/core/rfi.py b/sigpyproc/core/rfi.py
index 39d8d96..6e0931c 100644
--- a/sigpyproc/core/rfi.py
+++ b/sigpyproc/core/rfi.py
@@ -1,25 +1,55 @@
 from __future__ import annotations
-import numpy as np
+
+from typing import TYPE_CHECKING
+
 import attrs
 import h5py
+import numpy as np
 
-from numpy import typing as npt
-from typing import Callable
-
-from iqrm import iqrm_mask
+from sigpyproc.core import stats
 from sigpyproc.header import Header
-from sigpyproc.core.stats import zscore_double_mad
+
+if TYPE_CHECKING:
+    from typing import Callable
 
 
-def double_mad_mask(array: npt.ArrayLike, threshold: float = 3) -> np.ndarray:
+def double_mad_mask(array: np.ndarray, threshold: float = 3) -> np.ndarray:
     """Calculate the mask of an array using the double MAD (Modified z-score).
 
     Parameters
     ----------
-    array : :py:obj:`~numpy.typing.ArrayLike`
+    array : :py:obj:`~numpy.ndarray`
+        The array to calculate the mask of.
+    threshold : float, optional
+        Threshold in sigmas, by default 3.0
+
+    Returns
+    -------
+    :py:obj:`~numpy.ndarray`
+        The mask for the array.
+
+    Raises
+    ------
+    ValueError
+        If the threshold is not positive.
+    """
+    if threshold <= 0:
+        msg = f"threshold must be positive, got {threshold}"
+        raise ValueError(msg)
+    return np.abs(stats.zscore_double_mad(array)) > threshold
+
+
+def iqrm_mask(array: np.ndarray, threshold: float = 3, radius: int = 5) -> np.ndarray:
+    """Calculate the mask of an array using the IQRM (Interquartile Range Method).
+
+    Parameters
+    ----------
+    array : :py:obj:`~numpy.ndarray`
         The array to calculate the mask of.
     threshold : float, optional
         Threshold in sigmas, by default 3.0
+    radius : int, optional
+        Radius to calculate the IQRM, by default 5
 
     Returns
     -------
@@ -31,14 +61,25 @@ def double_mad_mask(array: npt.ArrayLike, threshold: float = 3) -> np.ndarray:
     ValueError
         If the threshold is not positive.
     """
-    array = np.asarray(array)
     if threshold <= 0:
-        raise ValueError("threshold must be positive")
-    return np.abs(zscore_double_mad(array)) > threshold
+        msg = f"threshold must be positive, got {threshold}"
+        raise ValueError(msg)
+    mask = np.zeros_like(array, dtype="bool")
+    lags = np.concatenate([np.arange(-radius, 0), np.arange(1, radius + 1)])
+    shifted_x = np.lib.stride_tricks.as_strided(
+        np.pad(array, radius, mode="edge"),
+        shape=(len(array), 2 * radius + 1),
+        strides=array.strides * 2,
+    )
+    lagged_diffs = array[:, np.newaxis] - shifted_x[:, lags + radius]
+    lagged_diffs = lagged_diffs.T
+    for lagged_diff in lagged_diffs:
+        mask = np.logical_or(mask, np.abs(stats.zscore_iqr(lagged_diff)) > threshold)
+    return mask
 
 
 @attrs.define(auto_attribs=True, slots=True)
-class RFIMask(object):
+class RFIMask:
     threshold: float
     header: Header
     chan_mean: np.ndarray
@@ -51,7 +92,7 @@ class RFIMask(object):
     chan_mask: np.ndarray = attrs.field()
 
     @chan_mask.default
-    def _set_chan_mask(self):
+    def _set_chan_mask(self) -> np.ndarray:
         return np.zeros(self.header.nchans, dtype="bool")
 
     @property
@@ -64,7 +105,7 @@ def masked_fraction(self) -> float:
         """float: Fraction of channels masked."""
         return self.num_masked * 100 / self.header.nchans
 
-    def apply_mask(self, chanmask: npt.ArrayLike) -> None:
+    def apply_mask(self, chanmask: np.ndarray) -> None:
         """Apply a channel mask to the current mask.
 
         Parameters
@@ -79,12 +120,11 @@ def apply_mask(self, chanmask: npt.ArrayLike) -> None:
         """
         chanmask = np.asarray(chanmask, dtype="bool")
         if chanmask.size != self.header.nchans:
-            raise ValueError(
-                f"chanmask len {chanmask.size} does not match nchans {self.header.nchans}"
-            )
+            msg = f"chanmask ({chanmask.size}) not equal nchans ({self.header.nchans})"
+            raise ValueError(msg)
         self.chan_mask = np.logical_or(self.chan_mask, chanmask)
 
-    def apply_method(self, method: str) -> None:
+    def apply_method(self, method: str = "mad") -> None:
         """Apply a mask method using channel statistics.
 
         Parameters
@@ -97,21 +137,20 @@ def apply_method(self, method: str) -> None:
         ValueError
             If the method is not supported.
         """
-        if method == "iqrm":
-            method_funcn = lambda arr, thres: iqrm_mask(  # noqa: E731
-                arr, radius=0.1 * self.header.nchans, threshold=thres
-            )
-        elif method == "mad":
+        if method == "mad":
             method_funcn = double_mad_mask
+        elif method == "iqrm":
+            method_funcn = iqrm_mask
         else:
-            raise ValueError(f"Unknown method {method}")
+            msg = f"method {method} not supported"
+            raise ValueError(msg)
         mask_var = method_funcn(self.chan_var, self.threshold)
         mask_skew = method_funcn(self.chan_skew, self.threshold)
         mask_kurtosis = method_funcn(self.chan_kurtosis, self.threshold)
         mask_stats = np.logical_or.reduce((mask_var, mask_skew, mask_kurtosis))
         self.chan_mask = np.logical_or(self.chan_mask, mask_stats)
 
-    def apply_funcn(self, custom_funcn: Callable[[npt.ArrayLike], np.ndarray]) -> None:
+    def apply_funcn(self, custom_funcn: Callable[[np.ndarray], np.ndarray]) -> None:
         """Apply a custom function to the channel mask.
 
         Parameters
@@ -125,7 +164,8 @@ def apply_funcn(self, custom_funcn: Callable[[npt.ArrayLike], np.ndarray]) -> No
             If the custom_funcn is not callable.
         """
         if not callable(custom_funcn):
-            raise ValueError(f"{custom_funcn} is not callable")
+            msg = f"{custom_funcn} is not callable"
+            raise TypeError(msg)
         self.chan_mask = custom_funcn(self.chan_mask)
 
     def to_file(self, filename: str | None = None) -> str:
@@ -168,12 +208,12 @@ def from_file(cls, filename: str) -> RFIMask:
             The loaded mask.
         """
         with h5py.File(filename, "r") as fp:
-            fp_attrs = {key: val for key, val in fp.attrs.items()}
+            fp_attrs = dict(fp.attrs.items())
             fp_stats = {key: np.array(val) for key, val in fp.items()}
         hdr_checked = {
             key: value
             for key, value in fp_attrs.items()
-            if key in attrs.fields_dict(Header).keys()
+            if key in attrs.fields_dict(Header)
         }
         kws = {
             "header": Header(**hdr_checked),
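Review note: `double_mad_mask` now delegates to `stats.zscore_double_mad`, and the new in-house `iqrm_mask` replaces the external `iqrm` dependency. A simplified single-MAD version of the same masking idea (the library's double-MAD variant additionally uses separate left/right scales), with the same `norm.ppf(0.75)` constant:

```python
# Simplified MAD-based channel masking sketch (illustration, not the library code).
import numpy as np

rng = np.random.default_rng(42)
stat = rng.normal(0, 1, 256)  # e.g. per-channel variance or skewness
stat[[10, 99]] = 15.0  # two RFI-like outlier channels

scale = 0.6744897501960817  # scipy.stats.norm.ppf(0.75)
diff = stat - np.median(stat)
mad = np.median(np.abs(diff)) / scale
mask = np.abs(diff / mad) > 3.0  # threshold in sigmas
print(np.flatnonzero(mask))  # flags channels 10 and 99 (plus any chance >3 sigma samples)
```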
""" - if method == "iqrm": - method_funcn = lambda arr, thres: iqrm_mask( # noqa: E731 - arr, radius=0.1 * self.header.nchans, threshold=thres - ) - elif method == "mad": + if method == "mad": method_funcn = double_mad_mask + elif method == "iqrm": + method_funcn = iqrm_mask else: - raise ValueError(f"Unknown method {method}") + msg = f"method {method} not supported" + raise ValueError(msg) mask_var = method_funcn(self.chan_var, self.threshold) mask_skew = method_funcn(self.chan_skew, self.threshold) mask_kurtosis = method_funcn(self.chan_kurtosis, self.threshold) mask_stats = np.logical_or.reduce((mask_var, mask_skew, mask_kurtosis)) self.chan_mask = np.logical_or(self.chan_mask, mask_stats) - def apply_funcn(self, custom_funcn: Callable[[npt.ArrayLike], np.ndarray]) -> None: + def apply_funcn(self, custom_funcn: Callable[[np.ndarray], np.ndarray]) -> None: """Apply a custom function to the channel mask. Parameters @@ -125,7 +164,8 @@ def apply_funcn(self, custom_funcn: Callable[[npt.ArrayLike], np.ndarray]) -> No If the custom_funcn is not callable. """ if not callable(custom_funcn): - raise ValueError(f"{custom_funcn} is not callable") + msg = f"{custom_funcn} is not callable" + raise TypeError(msg) self.chan_mask = custom_funcn(self.chan_mask) def to_file(self, filename: str | None = None) -> str: @@ -168,12 +208,12 @@ def from_file(cls, filename: str) -> RFIMask: The loaded mask. """ with h5py.File(filename, "r") as fp: - fp_attrs = {key: val for key, val in fp.attrs.items()} + fp_attrs = dict(fp.attrs.items()) fp_stats = {key: np.array(val) for key, val in fp.items()} hdr_checked = { key: value for key, value in fp_attrs.items() - if key in attrs.fields_dict(Header).keys() + if key in attrs.fields_dict(Header) } kws = { "header": Header(**hdr_checked), diff --git a/sigpyproc/core/stats.py b/sigpyproc/core/stats.py index 2dd1114..2b86716 100644 --- a/sigpyproc/core/stats.py +++ b/sigpyproc/core/stats.py @@ -1,60 +1,77 @@ from __future__ import annotations -import numpy as np + import bottleneck as bn -from numpy import typing as npt +import numpy as np from sigpyproc.core import kernels -def running_median(array, window): +def running_filter( + array: np.ndarray, + window: int, + filter_func: str = "mean", +) -> np.ndarray: """ - Calculate the running median of an array. + Calculate the running filter of an array. Parameters ---------- array : numpy.ndarray - The array to calculate the running median of. + The array to calculate the running filter of. + window : int + The window size of the filter. + filter_func : str, optional + The filter function to use, by default "mean". Returns ------- numpy.ndarray - The running median of the array. + The running filter of the array. + + Raises + ------ + ValueError + If the filter function is not "mean" or "median". + + Notes + ----- + Window edges are handled by reflecting about the edges. """ pad_size = ( (window // 2, window // 2) if window % 2 else (window // 2, window // 2 - 1) ) - padded = np.pad(array, pad_size, "symmetric") - - median = bn.move_median(padded, window) - return median[window - 1 :] - + padded_ar = np.pad(array, pad_size, "symmetric") + if filter_func == "mean": + filtered_ar = bn.move_mean(padded_ar, window) + elif filter_func == "median": + filtered_ar = bn.move_median(padded_ar, window) + else: + msg = f"Filter function not recognized: {filter_func}" + raise ValueError(msg) + return filtered_ar[window - 1 :] -def running_mean(array, window): - """ - Calculate the running mean of an array. 
-
-def running_mean(array, window):
-    """
-    Calculate the running mean of an array.
+def zscore_iqr(array: np.ndarray) -> np.ndarray:
+    """Calculate the z-score of an array using the IQR (Interquartile Range).
 
     Parameters
     ----------
-    array : numpy.ndarray
-        The array to calculate the running mean of.
+    array : :py:obj:`~numpy.ndarray`
+        The array to calculate the z-score of.
 
     Returns
     -------
-    numpy.ndarray
-        The running mean of the array.
+    :py:obj:`~numpy.ndarray`
+        The z-score of the array.
     """
-    pad_size = (
-        (window // 2, window // 2) if window % 2 else (window // 2, window // 2 - 1)
-    )
-    padded = np.pad(array, pad_size, "symmetric")
-
-    mean = bn.move_mean(padded, window)
-    return mean[window - 1 :]
+    q1, median, q3 = np.percentile(array, [25, 50, 75])
+    iqr = (q3 - q1) / 1.349
+    diff = array - median
+    return np.divide(diff, iqr, out=np.zeros_like(diff), where=iqr != 0)
 
 
-def zscore_mad(array: npt.ArrayLike) -> np.ndarray:
+def zscore_mad(array: np.ndarray) -> np.ndarray:
     """Calculate the z-score of an array using the MAD (Modified z-score).
 
     Parameters
@@ -74,18 +91,13 @@ def zscore_mad(array: npt.ArrayLike) -> np.ndarray:
     """
     scale_mad = 0.6744897501960817  # scipy.stats.norm.ppf(3/4.)
     scale_aad = np.sqrt(2 / np.pi)
-    array = np.asarray(array)
-    med = np.median(array)
-    diff = array - med
+    diff = array - np.median(array)
     mad = np.median(np.abs(diff)) / scale_mad
-    if mad == 0:
-        std = np.mean(np.abs(diff)) / scale_aad
-    else:
-        std = mad
+    std = np.mean(np.abs(diff)) / scale_aad if mad == 0 else mad
     return np.divide(diff, std, out=np.zeros_like(diff), where=std != 0)
 
 
-def zscore_double_mad(array: npt.ArrayLike) -> np.ndarray:
+def zscore_double_mad(array: np.ndarray) -> np.ndarray:
     """Calculate the modified z-score of an array using the Double MAD.
 
     Parameters
@@ -126,8 +138,8 @@ def zscore_double_mad(array: npt.ArrayLike) -> np.ndarray:
     return np.divide(diff, std_map, out=np.zeros_like(diff), where=std_map != 0)
 
 
-class ChannelStats(object):
-    def __init__(self, nchans: int, nsamps: int):
+class ChannelStats:
+    def __init__(self, nchans: int, nsamps: int) -> None:
-        """Central central moments for filterbank channels in one pass.
+        """Compute central moments for filterbank channels in one pass.
 
         Parameters
@@ -151,7 +163,7 @@ def __init__(self, nchans: int, nsamps: int):
 
     @property
     def mbag(self) -> kernels.MomentsBag:
-        """:class:`~sigpyproc.core.kernels.MomentsBag`: The central moments of the data."""
+        """:class:`~sigpyproc.core.kernels.MomentsBag`: Central moments of the data."""
         return self._mbag
 
     @property
@@ -214,14 +226,23 @@ def kurtosis(self) -> np.ndarray:
         )
 
     def push_data(
-        self, array: np.ndarray, gulp_size: int, start_index: int, mode: str = "basic"
-    ):
+        self,
+        array: np.ndarray,
+        gulp_size: int,
+        start_index: int,
+        mode: str = "basic",
+    ) -> None:
         if mode == "basic":
-            kernels.compute_online_moments_basic(array, self.mbag, gulp_size, start_index)
+            kernels.compute_online_moments_basic(
+                array,
+                self.mbag,
+                gulp_size,
+                start_index,
+            )
         else:
             kernels.compute_online_moments(array, self.mbag, gulp_size, start_index)
 
-    def __add__(self, other: type[ChannelStats]) -> type[ChannelStats]:
+    def __add__(self, other: ChannelStats) -> ChannelStats:
         """Add two ChannelStats objects together as if all the data belonged to one.
 
         Parameters
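[editor's note] A worked example of the new zscore_iqr, fully determined by the code above; the sample values are illustrative:

    import numpy as np
    from sigpyproc.core.stats import zscore_iqr

    x = np.array([1.0, 2.0, 3.0, 4.0, 100.0])
    # median = 3, q1 = 2, q3 = 4, so the robust scale is (4 - 2) / 1.349
    z = zscore_iqr(x)
    print(z[-1])  # (100 - 3) * 1.349 / 2 ~= 65.4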
""" if not isinstance(other, ChannelStats): - raise TypeError("ChannelStats can only be added to other ChannelStats object") + msg = f"Only ChannelStats can be added together, not {type(other)}" + raise TypeError(msg) combined = ChannelStats(self.nchans, self.nsamps) kernels.add_online_moments(self.mbag, other.mbag, combined.mbag) diff --git a/sigpyproc/io/bits.py b/sigpyproc/io/bits.py index 465e0f6..08995e7 100644 --- a/sigpyproc/io/bits.py +++ b/sigpyproc/io/bits.py @@ -1,5 +1,7 @@ from __future__ import annotations -from typing import ClassVar, Any + +from typing import Any, ClassVar + import attrs import numpy as np @@ -9,18 +11,27 @@ def unpack( - array: np.ndarray, nbits: int, unpacked: np.ndarray | None = None + array: np.ndarray, + nbits: int, + unpacked: np.ndarray | None = None, + *, + bitorder: str = "big", + parallel: bool = False, ) -> np.ndarray: - """Unpack 1, 2 and 4 bit array. Only unpacks in big endian bit ordering. + """Unpack 1, 2 and 4-bit data packed as 8-bit numpy array. Parameters ---------- array : numpy.ndarray Array to unpack. nbits : int - Number of bits to unpack. + Number of bits of the packed data. unpacked : numpy.ndarray, optional Array to unpack into. + bitorder : str, optional + Bit order of the packed data. + parallel : bool, optional + Whether to use parallel unpacking. Returns ------- @@ -30,37 +41,55 @@ def unpack( Raises ------ ValueError + if input array is not uint8 type if nbits is not 1, 2, or 4 + if bitorder is not 'big' or 'little' + if unpacked array is not of the correct size """ if array.dtype != np.uint8: - raise ValueError(f"Input array must be uint8, got {array.dtype}") + msg = f"Input array must be uint8, got {array.dtype}" + raise ValueError(msg) + if nbits not in {1, 2, 4}: + msg = f"nbits must be 1, 2, or 4, got {nbits}" + raise ValueError(msg) + if (not bitorder) or (bitorder[0] not in {"b", "l"}): + msg = f"bitorder must be 'big' or 'little', got {bitorder}" + raise ValueError(msg) + bitorder_str = "big" if bitorder[0] == "b" else "little" + parallel_str = "" if parallel else "_serial" bitfact = 8 // nbits if unpacked is None: unpacked = np.zeros(shape=array.size * bitfact, dtype=np.uint8) elif unpacked.size != array.size * bitfact: - raise ValueError( - f"Unpacking array must be {bitfact} x input size, got {unpacked.size}" - ) - if nbits == 1: - kernels.unpack1_8(array, unpacked) - elif nbits == 2: - kernels.unpack2_8(array, unpacked) - elif nbits == 4: - kernels.unpack4_8(array, unpacked) - else: - raise ValueError(f"nbits must be 1, 2, or 4, got {nbits}") + msg = f"Unpacking array must be {bitfact} x input size, got {unpacked.size}" + raise ValueError(msg) + unpack_func = getattr(kernels, f"unpack{nbits:d}_8_{bitorder_str}{parallel_str}") + unpack_func(array, unpacked) return unpacked -def pack(array, nbits): - """Pack 1, 2 and 4 bit array. Only packs in big endian bit ordering. +def pack( + array: np.ndarray, + nbits: int, + packed: np.ndarray | None = None, + *, + bitorder: str = "big", + parallel: bool = False, +) -> np.ndarray: + """Pack 1, 2 and 4-bit data into 8-bit numpy array. Parameters ---------- array : numpy.ndarray Array to pack. nbits : int - Number of bits to pack. + Number of bits of the unpacked data. + packed : numpy.ndarray, optional + Array to pack into. + bitorder : str, optional + Bit order in which to pack the data. + parallel : bool, optional + Whether to use parallel packing. 
 
 
-def pack(array, nbits):
-    """Pack 1, 2 and 4 bit array. Only packs in big endian bit ordering.
+def pack(
+    array: np.ndarray,
+    nbits: int,
+    packed: np.ndarray | None = None,
+    *,
+    bitorder: str = "big",
+    parallel: bool = False,
+) -> np.ndarray:
+    """Pack 1, 2 and 4-bit data into 8-bit numpy array.
 
     Parameters
     ----------
     array : numpy.ndarray
         Array to pack.
     nbits : int
-        Number of bits to pack.
+        Number of bits of the unpacked data.
+    packed : numpy.ndarray, optional
+        Array to pack into.
+    bitorder : str, optional
+        Bit order in which to pack the data.
+    parallel : bool, optional
+        Whether to use parallel packing.
 
     Returns
     -------
@@ -70,22 +99,35 @@ def pack(array, nbits):
     Raises
     ------
     ValueError
+        if input array is not uint8 type
         if nbits is not 1, 2, or 4
+        if bitorder is not 'big' or 'little'
+        if packed array is not of the correct size
     """
-    assert array.dtype == np.uint8, "Array must be uint8"
-    if nbits == 1:
-        packed = np.packbits(array, bitorder="big")
-    elif nbits == 2:
-        packed = kernels.pack2_8(array)
-    elif nbits == 4:
-        packed = kernels.pack4_8(array)
-    else:
-        raise ValueError("nbits must be 1, 2, or 4")
+    if array.dtype != np.uint8:
+        msg = f"Input array must be uint8, got {array.dtype}"
+        raise ValueError(msg)
+    if nbits not in {1, 2, 4}:
+        msg = f"nbits must be 1, 2, or 4, got {nbits}"
+        raise ValueError(msg)
+    if (not bitorder) or (bitorder[0] not in {"b", "l"}):
+        msg = f"bitorder must be 'big' or 'little', got {bitorder}"
+        raise ValueError(msg)
+    bitorder_str = "big" if bitorder[0] == "b" else "little"
+    parallel_str = "" if parallel else "_serial"
+    bitfact = 8 // nbits
+    if packed is None:
+        packed = np.zeros(shape=array.size // bitfact, dtype=np.uint8)
+    elif packed.size != array.size // bitfact:
+        msg = f"packing array must be input size // {bitfact}, got {packed.size}"
+        raise ValueError(msg)
+    pack_func = getattr(kernels, f"pack{nbits:d}_8_{bitorder_str}{parallel_str}")
+    pack_func(array, packed)
     return packed
 
 
 @attrs.define(auto_attribs=True, frozen=True, slots=True)
-class BitsInfo(object):
+class BitsInfo:
     """Class to handle bits info.
 
     Raises
@@ -157,12 +199,12 @@ def to_dict(self) -> dict[str, Any]:
         attributes = attrs.asdict(self)
         prop = {
             key: getattr(self, key)
-            for key, value in vars(type(self)).items()  # noqa: WPS421
+            for key, value in vars(type(self)).items()
             if isinstance(value, property)
         }
         attributes.update(prop)
         return attributes
 
     @digi_sigma.default
-    def _set_digi_sigma(self):
+    def _set_digi_sigma(self) -> float:
         return self.default_sigma[self.nbits]
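[editor's note] With pack now mirroring unpack, the round trip should be exact for uint8 data; a minimal sketch with random, illustrative input:

    import numpy as np
    from sigpyproc.io.bits import pack, unpack

    rng = np.random.default_rng(1)
    original = rng.integers(256, size=64, dtype=np.uint8)
    restored = pack(unpack(original, nbits=4), nbits=4)
    np.testing.assert_array_equal(restored, original)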
""" if nbytes <= 0: - raise ValueError(f"Requested buffer size is invalid {nbytes}") + msg = f"Requested buffer size is invalid {nbytes}" + raise ValueError(msg) try: buffer = allocator(nbytes) - except Exception as exc: - raise RuntimeError(f"Failed to allocate buffer of size {nbytes}") from exc + except Exception as exc: # noqa: BLE001 + msg = f"Failed to allocate buffer of size {nbytes}" + raise RuntimeError(msg) from exc if not isinstance(buffer, Buffer): - raise TypeError(f"Allocator did not return a buffer object {type(buffer)}") + msg = f"Allocator did not return a buffer object {type(buffer)}" + raise TypeError(msg) - allocated_nbytes = len(buffer) # type: ignore + allocated_nbytes = len(buffer) # type: ignore [arg-type] if allocated_nbytes != nbytes: - raise ValueError( - f"Allocated buffer is not the expected size {allocated_nbytes} (actual) != {nbytes} (expected)" + msg = ( + f"Allocated buffer is not the expected size {allocated_nbytes} " + f"(actual) != {nbytes} (expected)" ) + raise ValueError(msg) return buffer -class FileBase(object): +class FileBase: """File I/O base class.""" def __init__(self, files: list[str], mode: str) -> None: @@ -65,10 +78,10 @@ def __init__(self, files: list[str], mode: str) -> None: self.ifile_cur = -1 self._open(ifile=0) - def __enter__(self) -> FileBase: + def __enter__(self) -> Self: return self - def __exit__(self, exc_type, exc_value, traceback) -> None: + def __exit__(self, exc_type, exc_value, traceback) -> None: # noqa: ANN001 self._close_current() def _open(self, ifile: int) -> None: @@ -87,9 +100,8 @@ def _open(self, ifile: int) -> None: if ifile is out of bounds """ if ifile < 0 or ifile >= len(self.files): - raise ValueError( - f"ifile should be between 0 and {len(self.files) - 1}, got {ifile}" - ) + msg = f"ifile should be between 0 and {len(self.files) - 1}, got {ifile}" + raise ValueError(msg) if ifile != self.ifile_cur: file_obj = self.opener(self.files[ifile], mode=self.mode) @@ -126,7 +138,9 @@ class FileReader(FileBase): number of bits per sample in the files, by default 8 """ - def __init__(self, stream_info: StreamInfo, mode: str = "r", nbits: int = 8) -> None: + def __init__( + self, stream_info: StreamInfo, mode: str = "r", nbits: int = 8 + ) -> None: self.sinfo = stream_info self.nbits = nbits self.bitsinfo = BitsInfo(nbits) @@ -166,7 +180,8 @@ def cread(self, nunits: int) -> np.ndarray: if file is closed. """ if self.file_obj.closed: - raise IOError("Cannot read closed file.") + msg = f"Cannot read from closed file {self.files[self.ifile_cur]}" + raise OSError(msg) count = nunits // self.bitsinfo.bitfact data = [] @@ -187,7 +202,11 @@ def cread(self, nunits: int) -> np.ndarray: return unpack(data_ar, self.nbits) return data_ar - def creadinto(self, read_buffer: Buffer, unpack_buffer: Buffer | None = None) -> int: + def creadinto( + self, + read_buffer: Buffer, + unpack_buffer: Buffer | None = None, + ) -> int: """Read from file stream into a buffer of pre-defined length. Parameters @@ -196,7 +215,8 @@ def creadinto(self, read_buffer: Buffer, unpack_buffer: Buffer | None = None) -> An object exposing the Python Buffer Protocol interface [PEP 3118] unpack_buffer : Buffer, optional - An object exposing the Python Buffer Protocol interface [PEP 3118], by default None + An object exposing the Python Buffer Protocol interface [PEP 3118], + by default None Returns ------- @@ -215,7 +235,8 @@ def creadinto(self, read_buffer: Buffer, unpack_buffer: Buffer | None = None) -> the number of bytes returned will be zero. 
""" if self.file_obj.closed: - raise IOError("Cannot read closed file.") + msg = f"Cannot read from closed file {self.files[self.ifile_cur]}" + raise OSError(msg) nbytes = 0 read_buffer_view = memoryview(read_buffer) @@ -225,8 +246,10 @@ def creadinto(self, read_buffer: Buffer, unpack_buffer: Buffer | None = None) -> if self.eos(): # We have reached the end of the stream break - # Might be non-blocking IO, so maybe try again - raise IOError("file might in non-blocking mode") + else: # noqa: RET508 + # Might be non-blocking IO, so maybe try again + msg = "file might in non-blocking mode" + raise OSError(msg) else: nbytes += nbytes_read if nbytes == read_buffer_view.nbytes or self.eos(): @@ -261,7 +284,8 @@ def seek(self, offset: int, whence: int = 0) -> None: if whence is not 0 or 1. """ if self.file_obj.closed: - raise ValueError("Cannot read closed file.") + msg = f"Cannot read from closed file {self.files[self.ifile_cur]}" + raise OSError(msg) if whence == 0: self._seek_set(offset) @@ -269,7 +293,8 @@ def seek(self, offset: int, whence: int = 0) -> None: offset_start = offset + self.cur_data_pos_stream self._seek_set(offset_start) else: - raise ValueError("whence should be either 0 (SEEK_SET) or 1 (SEEK_CUR)") + msg = "whence should be either 0 (SEEK_SET) or 1 (SEEK_CUR)" + raise ValueError(msg) def _seek2hdr(self, fileid: int) -> None: """Go to the header end position of the file with the given fileid.""" @@ -278,7 +303,8 @@ def _seek2hdr(self, fileid: int) -> None: def _seek_set(self, offset: int) -> None: if offset < 0 or offset >= self.sinfo.get_combined("datalen"): - raise ValueError(f"offset out of bounds: {offset}") + msg = f"offset out of bounds: {offset}" + raise ValueError(msg) fileid = np.where(offset < self.sinfo.cumsum_datalens)[0][0] self._seek2hdr(fileid) @@ -415,7 +441,7 @@ def close(self) -> None: self._close_current() -class Transform(object): +class Transform: """A class to transform data to the quantized format. 
@@ -415,7 +441,7 @@ def close(self) -> None:
         self._close_current()
 
 
-class Transform(object):
+class Transform:
     """A class to transform data to the quantized format.
 
     Parameters
@@ -448,7 +474,7 @@ def __init__(
         digi_max: float,
         interval_seconds: float = 10,
         constant_offset_scale: bool = False,
-    ):
+    ) -> None:
         self.tsamp = tsamp
         self.nchans = nchans
         self.interval_seconds = interval_seconds
diff --git a/sigpyproc/io/sigproc.py b/sigpyproc/io/sigproc.py
index a6d9a71..7c50664 100644
--- a/sigpyproc/io/sigproc.py
+++ b/sigpyproc/io/sigproc.py
@@ -1,13 +1,15 @@
 from __future__ import annotations
+
 import struct
+from pathlib import Path
+from typing import IO
+
 import attrs
 import numpy as np
-
-from bidict import bidict
-from astropy.time import Time, TimeDelta
-from astropy.coordinates import SkyCoord
 from astropy import units
-
+from astropy.coordinates import SkyCoord
+from astropy.time import Time, TimeDelta
+from bidict import bidict
 
 header_keys = {
     "signed": "b",
@@ -49,7 +51,7 @@
         "SRT": 10,
         "Unknown": 11,
         "CHIME": 20,
-    }
+    },
 )
 
 machine_ids = bidict(
@@ -65,12 +67,12 @@
         "PFFTS": 8,
         "Unknown": 9,
         "CHIME": 20,
-    }
+    },
 )
 
 
 @attrs.define(frozen=True, kw_only=True)
-class FileInfo(object):
+class FileInfo:
     """Class to handle individual file information."""
 
     filename: str
@@ -82,12 +84,12 @@
     @classmethod
     def from_dict(cls, info: dict) -> FileInfo:
         """Create FileInfo object from a dictionary."""
-        info_filtered = {key: info[key] for key in attrs.fields_dict(cls).keys()}
+        info_filtered = {key: info[key] for key in attrs.fields_dict(cls)}
         return cls(**info_filtered)
 
 
 @attrs.define(frozen=True)
-class StreamInfo(object):
+class StreamInfo:
     """Class to handle stream information as a list of FileInfo objects."""
 
     entries: list[FileInfo] = attrs.Factory(list)
@@ -100,7 +102,8 @@ def cumsum_datalens(self) -> np.ndarray:
     def add_entry(self, finfo: FileInfo) -> None:
         """Add a FileInfo entry to the StreamInfo object."""
         if not isinstance(finfo, FileInfo):
-            raise ValueError("Input must be a FileInfo object")
+            msg = f"Input must be a FileInfo object, got {type(finfo)}"
+            raise TypeError(msg)
         self.entries.append(finfo)
 
     def get_combined(self, key: str) -> int:
@@ -129,7 +132,7 @@ def check_contiguity(self, tsamp: float) -> bool:
         return True
 
 
-def edit_header(filename: str, key: str, value: int | float | str) -> None:
+def edit_header(filename: str, key: str, value: float | str) -> None:
     """Edit a sigproc style header directly in place for the given file.
 
     Parameters
@@ -153,7 +156,8 @@
         change the file on which it is being operated.
     """
     if key not in header_keys:
-        raise ValueError(f"Key '{key}' is not a valid sigproc key.")
+        msg = f"Key '{key}' is not a valid sigproc key."
+        raise ValueError(msg)
     header = parse_header(filename)
     if key == "source_name" and isinstance(value, str):
         oldlen = len(header["source_name"])
         value = value.ljust(oldlen)
     hdr.update({key: value})
     new_hdr = encode_header(hdr)
     if header["hdrlen"] == len(new_hdr):
-        with open(filename, "rb+") as fp:
+        with Path(filename).open("rb+") as fp:
             fp.seek(0)
             fp.write(new_hdr)
     else:
-        raise ValueError("New header is too long/short for file")
+        msg = f"New header is too long/short for file {filename}"
+        raise ValueError(msg)
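[editor's note] edit_header only rewrites the file in place when the encoded header length is unchanged, as enforced above; a hypothetical call for illustration (the filename and value are made up, and "tsamp" is assumed to encode to a fixed-width double so the length is preserved):

    from sigpyproc.io.sigproc import edit_header

    edit_header("observation.fil", "tsamp", 6.4e-05)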
-def parse_header_multi(filenames: str | list[str], check_contiguity: bool = True) -> dict:
+def parse_header_multi(
+    filenames: str | list[str],
+    *,
+    check_contiguity: bool = True,
+) -> dict:
     """Parse the metadata from Sigproc-style file/sequential files.
 
     Parameters
@@ -196,7 +205,8 @@
         sinfo.add_entry(FileInfo.from_dict(hdr))
 
     if check_contiguity and not sinfo.check_contiguity(header["tsamp"]):
-        raise ValueError("Files are not contiguous")
+        msg = f"Files {filenames} are not contiguous"
+        raise ValueError(msg)
 
     header["stream_info"] = sinfo
     header["nsamples"] = header["stream_info"].get_combined("nsamples")
@@ -221,14 +231,16 @@ def parse_header(filename: str) -> dict:
     IOError
         If file header is not in sigproc format
     """
-    with open(filename, "rb") as fp:
-        header = {}
+    with Path(filename).open("rb") as fp:
+        header: dict[str, float | str] = {}
         try:
             key = _read_string(fp)
         except struct.error:
-            raise IOError("File Header is not in sigproc format... Is file empty?")
+            msg = f"File {filename} header is not in sigproc format. Is the file empty?"
+            raise OSError(msg) from None
         if key != "HEADER_START":
-            raise IOError("File Header is not in sigproc format")
+            msg = f"File {filename} header is not in sigproc format."
+            raise OSError(msg)
         while True:
             key = _read_string(fp)
             if key == "HEADER_END":
                 break
@@ -238,12 +250,16 @@
             if key_fmt == "str":
                 header[key] = _read_string(fp)
             else:
-                header[key] = struct.unpack(key_fmt, fp.read(struct.calcsize(key_fmt)))[0]
+                header[key] = struct.unpack(key_fmt, fp.read(struct.calcsize(key_fmt)))[
+                    0
+                ]
         header["hdrlen"] = fp.tell()
         fp.seek(0, 2)
         header["filelen"] = fp.tell()
-        header["datalen"] = header["filelen"] - header["hdrlen"]
-        header["nsamples"] = 8 * header["datalen"] // header["nbits"] // header["nchans"]
+        header["datalen"] = int(header["filelen"]) - int(header["hdrlen"])
+        header["nsamples"] = (
+            8 * int(header["datalen"]) // int(header["nbits"]) // int(header["nchans"])
+        )
         fp.seek(0)
 
     header["filename"] = filename
@@ -270,10 +286,11 @@ def match_header(header1: dict, header2: dict) -> None:
         if key in keys_nomatch or key not in header_keys:
             continue
         if value != header2[key]:
-            raise ValueError(
-                f'Header key "{key} = {value} and {header2[key]}"'
-                + f'do not match for file {header2["filename"]}'
-            )
+            msg = (
+                f"Header key ({key} = {value}) and ({key} = {header2[key]}) "
+                f"do not match for file {header2['filename']}"
+            )
+            raise ValueError(msg)
 
 
 def encode_header(header: dict) -> bytes:
@@ -286,7 +303,7 @@ def encode_header(header: dict) -> bytes:
     """
     hdr_encoded = encode_key("HEADER_START")
 
-    for key in header.keys():
+    for key in header:
         if key not in header_keys:
             continue
         hdr_encoded += encode_key(key, value=header[key], value_type=header_keys[key])
@@ -296,7 +313,9 @@
 
 def encode_key(
-    key: str, value: int | float | str | None = None, value_type: str = "str"
+    key: str,
+    value: float | str | None = None,
+    value_type: str = "str",
 ) -> bytes:
     """Encode given header key to a bytes string.
 
@@ -342,23 +361,23 @@ def parse_radec(src_raj: float, src_dej: float) -> SkyCoord:
     :class:`~astropy.coordinates.SkyCoord`
         Astropy coordinate class
     """
-    ho, mi = divmod(src_raj, 10000)  # noqa: WPS432
+    ho, mi = divmod(src_raj, 10000)
     mi, se = divmod(mi, 100)
     sign = -1 if src_dej < 0 else 1
-    de, ami = divmod(abs(src_dej), 10000)  # noqa: WPS432
+    de, ami = divmod(abs(src_dej), 10000)
     ami, ase = divmod(ami, 100)
     radec_str = f"{int(ho)} {int(mi)} {se} {sign* int(de)} {int(ami)} {ase}"
     return SkyCoord(radec_str, unit=(units.hourangle, units.deg))
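[editor's note] A worked example of the packed sigproc coordinate convention handled by parse_radec; the values are illustrative:

    from sigpyproc.io.sigproc import parse_radec

    # 123456.7 encodes 12h34m56.7s; -123456.7 encodes -12d34m56.7s
    coord = parse_radec(123456.7, -123456.7)
    print(coord.to_string("hmsdms"))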
-def _read_string(fp):
+def _read_string(fp: IO) -> str:
     """Read the next sigproc-format string in the file.
 
     Parameters
     ----------
-    fp : file
+    fp : file object
         file object to read from.
 
     Returns
diff --git a/sigpyproc/params.py b/sigpyproc/params.py
index 48c4a9a..ddaa693 100644
--- a/sigpyproc/params.py
+++ b/sigpyproc/params.py
@@ -1,15 +1,19 @@
-import numpy as np
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
 
+import numpy as np
+from astropy import constants, units
 from bidict import bidict
-from astropy import units, constants
-from typing import Dict, Tuple, Callable
 
+if TYPE_CHECKING:
+    from typing import Callable
 
 DM_CONSTANT_LK = 4.148808e3  # L&K Handbook of Pulsar Astronomy
-DM_CONSTANT_MT = 1 / 0.000241  # TEMPO2 Manchester & Taylor (1972)  # noqa: WPS432
+DM_CONSTANT_MT = 1 / 0.000241  # TEMPO2 Manchester & Taylor (1972)
 DM_CONSTANT_SI = (
-    (constants.e.esu ** 2 / (2 * np.pi * constants.m_e * constants.c)).to(
-        units.s * units.MHz ** 2 * units.cm ** 3 / units.pc
+    (constants.e.esu**2 / (2 * np.pi * constants.m_e * constants.c)).to(
+        units.s * units.MHz**2 * units.cm**3 / units.pc,
     )
 ).value  # Precise SI constants
@@ -104,7 +108,7 @@
         5: "complex spectrum",
         6: "dedispersed subbands",
         10: "PSRFITS",
-    }
+    },
 )
 
 # convert between types from the struct module and numpy
@@ -115,7 +119,7 @@
 telescope_lats_longs = {"Effelsberg": (50.52485, 6.883593)}
 
 # useful for creating inf files
-presto_inf: Dict[str, Tuple[str, Callable, str]] = {
+presto_inf: dict[str, tuple[str, Callable, str]] = {
     "Data file name without suffix": ("basename", str, "s"),
     "Telescope used": ("telescope", str, "s"),
     "Instrument used": ("backend", str, "s"),
@@ -148,7 +152,7 @@
 }
 
 sigpyproc_to_psrfits = dict(
-    zip(psrfits_to_sigpyproc.values(), psrfits_to_sigpyproc.keys())
+    zip(psrfits_to_sigpyproc.values(), psrfits_to_sigpyproc.keys()),
 )
 
 sigproc_to_tempo = {0: "g", 1: "3", 3: "f", 4: "7", 6: "1", 8: "g", 5: "8"}
diff --git a/sigpyproc/timeseries.py b/sigpyproc/timeseries.py
index a9a1e54..8c022fd 100644
--- a/sigpyproc/timeseries.py
+++ b/sigpyproc/timeseries.py
@@ -135,7 +135,7 @@ def running_mean(self, window: int = 10001) -> TimeSeries:
         """
         if window < 1:
             raise ValueError("incorrect window size")
-        tim_ar = stats.running_mean(self, window)
+        tim_ar = stats.running_filter(self, window, filter_func="mean")
         return tim_ar.view(TimeSeries)
 
     def running_median(self, window: int = 10001) -> TimeSeries:
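[editor's note] Both TimeSeries smoothing methods now route through stats.running_filter; a quick equivalence sketch, where `tim` is a hypothetical existing TimeSeries:

    from sigpyproc.core import stats

    smooth_a = tim.running_mean(window=101)
    smooth_b = stats.running_filter(tim, window=101, filter_func="mean")
    # smooth_a and smooth_b should agree bin for bin (smooth_a is a TimeSeries view).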
""" - tim_ar = stats.running_median(self, window) + tim_ar = stats.running_filter(self, window, filter_func="median") return tim_ar.view(TimeSeries) def apply_boxcar(self, width: int) -> TimeSeries: @@ -182,7 +182,7 @@ def apply_boxcar(self, width: int) -> TimeSeries: """ if width < 1: raise ValueError("incorrect boxcar window size") - mean_ar = stats.running_mean(self, width) * np.sqrt(width) + mean_ar = stats.running_filter(self, width, filter_func="mean") * np.sqrt(width) ref_bin = -width // 2 + 1 if width % 2 else -width // 2 boxcar_ar = np.roll(mean_ar, ref_bin) return boxcar_ar.view(TimeSeries) diff --git a/tests/test_bits.py b/tests/test_bits.py index 347b4ba..d68154f 100644 --- a/tests/test_bits.py +++ b/tests/test_bits.py @@ -1,30 +1,106 @@ -import pytest import numpy as np +import pytest + from sigpyproc.io import bits -class TestUnpacking(object): +class TestUnpacking: @pytest.mark.parametrize("nbits", [1, 2, 4]) - def test_unpack_empty(self, nbits): + def test_unpack_empty(self, nbits: int) -> None: input_arr = np.empty((0,), dtype=np.uint8) output = bits.unpack(input_arr, nbits=nbits) np.testing.assert_array_equal(input_arr, output) + @pytest.mark.parametrize("nbits", [1]) + @pytest.mark.parametrize("bitorder", ["big", "little"]) + @pytest.mark.parametrize("parallel", [False, True]) + def test_unpack_1bit(self, nbits: int, bitorder: str, parallel: bool) -> None: # noqa: FBT001 + rng = np.random.default_rng() + arr = rng.integers(255, size=2**10, dtype=np.uint8) + expected = np.unpackbits(arr, bitorder=bitorder) # type: ignore [arg-type] + output_buff = np.zeros(arr.size * 8 // nbits, dtype=np.uint8) + output_buff = bits.unpack( + arr, + nbits, + unpacked=output_buff, + bitorder=bitorder, + parallel=parallel, + ) + output_return = bits.unpack( + arr, + nbits, + unpacked=None, + bitorder=bitorder, + parallel=parallel, + ) + np.testing.assert_array_equal(output_buff, expected, strict=True) + np.testing.assert_array_equal(output_return, expected, strict=True) + + @pytest.mark.parametrize("nbits", [1]) + @pytest.mark.parametrize("bitorder", ["big", "little"]) + @pytest.mark.parametrize("parallel", [False, True]) + def test_pack_1bit(self, nbits: int, bitorder: str, parallel: bool) -> None: # noqa: FBT001 + rng = np.random.default_rng() + arr = rng.integers((1 << nbits) - 1, size=2**10, dtype=np.uint8) + expected = np.packbits(arr, bitorder=bitorder) # type: ignore [arg-type] + output_buff = np.zeros(arr.size // 8, dtype=np.uint8) + output_buff = bits.pack( + arr, + nbits, + packed=output_buff, + bitorder=bitorder, + parallel=parallel, + ) + output_return = bits.pack( + arr, + nbits, + packed=None, + bitorder=bitorder, + parallel=parallel, + ) + np.testing.assert_array_equal(output_buff, expected, strict=True) + np.testing.assert_array_equal(output_return, expected, strict=True) + @pytest.mark.parametrize("nbits", [1, 2, 4]) - def test_packunpack(self, nbits): - input_arr = np.arange(255, dtype=np.uint8) - output = bits.pack(bits.unpack(input_arr, nbits=nbits), nbits=nbits) - np.testing.assert_array_equal(input_arr, output) + @pytest.mark.parametrize("bitorder", ["big", "little"]) + @pytest.mark.parametrize("parallel", [False, True]) + def test_packunpack(self, nbits: int, bitorder: str, parallel: bool) -> None: # noqa: FBT001 + rng = np.random.default_rng() + arr = rng.integers(255, size=2**10, dtype=np.uint8) + tmp_unpack = np.zeros(arr.size * 8 // nbits, dtype=np.uint8) + tmp_unpack = bits.unpack( + arr, + nbits=nbits, + unpacked=tmp_unpack, + bitorder=bitorder, + 
+            parallel=parallel,
+        )
+        output_buff = np.zeros_like(arr)
+        output_buff = bits.pack(
+            tmp_unpack,
+            nbits=nbits,
+            packed=output_buff,
+            bitorder=bitorder,
+            parallel=parallel,
+        )
+        output_return = bits.pack(
+            bits.unpack(arr, nbits=nbits, bitorder=bitorder, parallel=parallel),
+            nbits=nbits,
+            bitorder=bitorder,
+            parallel=parallel,
+        )
+        np.testing.assert_array_equal(output_buff, arr, strict=True)
+        np.testing.assert_array_equal(output_return, arr, strict=True)
 
-    def test_unpack_fail(self):
+    def test_unpack_fail(self) -> None:
         nbits = 10
         input_arr = np.arange(255, dtype=np.uint8)
         with np.testing.assert_raises(ValueError):
             bits.unpack(input_arr, nbits=nbits)
 
 
-class TestBitsInfo(object):
-    def test_nbits_4(self):
+class TestBitsInfo:
+    def test_nbits_4(self) -> None:
         bitsinfo = bits.BitsInfo(4)
         np.testing.assert_equal(bitsinfo.nbits, 4)
         np.testing.assert_equal(bitsinfo.dtype, np.uint8)
diff --git a/tests/test_kernels.py b/tests/test_kernels.py
index 3b169ac..cde96e9 100644
--- a/tests/test_kernels.py
+++ b/tests/test_kernels.py
@@ -1,43 +1,70 @@
 import numpy as np
-from sigpyproc.core import kernels, stats, rfi
+import pytest
+
+from sigpyproc.core import kernels, rfi, stats
 from sigpyproc.header import Header
 
 
-class TestKernels(object):
-    def test_unpack1_8(self):
-        input_arr = np.array([0, 2, 7, 23], dtype=np.uint8)
-        expected_bit1 = np.array(
-            [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0,
-             0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 1, 0, 1, 1, 1], dtype=np.uint8
+class TestKernels:
+    def test_unpack1_8(self) -> None:
+        input_arr = np.array([7, 23], dtype=np.uint8)
+        expected_big = np.array(
+            [0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 1, 0, 1, 1, 1],
+            dtype=np.uint8,
         )
-        unpacked = np.empty_like(expected_bit1)
-        np.testing.assert_array_equal(expected_bit1, kernels.unpack1_8(input_arr, unpacked))
-
-    def test_unpack2_8(self):
-        input_arr = np.array([0, 2, 7, 23], dtype=np.uint8)
-        expected_bit2 = np.array(
-            [0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 1, 3, 0, 1, 1, 3], dtype=np.uint8
+        expected_little = np.array(
+            [1, 1, 1, 0, 0, 0, 0, 0, 1, 1, 1, 0, 1, 0, 0, 0],
+            dtype=np.uint8,
         )
-        unpacked = np.empty_like(expected_bit2)
-        np.testing.assert_array_equal(expected_bit2, kernels.unpack2_8(input_arr, unpacked))
+        unpacked_big = np.empty_like(expected_big)
+        unpacked_little = np.empty_like(expected_little)
+        kernels.unpack1_8_big(input_arr, unpacked_big)
+        kernels.unpack1_8_little(input_arr, unpacked_little)
+        np.testing.assert_array_equal(unpacked_big, expected_big, strict=True)
+        np.testing.assert_array_equal(unpacked_little, expected_little, strict=True)
 
-    def test_unpack4_8(self):
-        input_arr = np.array([0, 2, 7, 23], dtype=np.uint8)
-        expected_bit4 = np.array([0, 0, 0, 2, 0, 7, 1, 7], dtype=np.uint8)
-        unpacked = np.empty_like(expected_bit4)
-        np.testing.assert_array_equal(expected_bit4, kernels.unpack4_8(input_arr, unpacked))
+    def test_unpack2_8(self) -> None:
+        input_arr = np.array([7, 23], dtype=np.uint8)
+        expected_big = np.array(
+            [0, 0, 1, 3, 0, 1, 1, 3],
+            dtype=np.uint8,
+        )
+        expected_little = np.array(
+            [3, 1, 0, 0, 3, 1, 1, 0],
+            dtype=np.uint8,
+        )
+        unpacked_big = np.empty_like(expected_big)
+        unpacked_little = np.empty_like(expected_little)
+        kernels.unpack2_8_big(input_arr, unpacked_big)
+        kernels.unpack2_8_little(input_arr, unpacked_little)
+        np.testing.assert_array_equal(unpacked_big, expected_big, strict=True)
+        np.testing.assert_array_equal(unpacked_little, expected_little, strict=True)
 
-    def test_pack2_8(self):
-        input_arr = np.arange(255, dtype=np.uint8)
-        unpacked = np.empty(input_arr.size * 4, dtype="ubyte")
-        output = kernels.pack2_8(kernels.unpack2_8(input_arr, unpacked))
-        np.testing.assert_array_equal(input_arr, output)
+    def test_unpack4_8(self) -> None:
+        input_arr = np.array([7, 23], dtype=np.uint8)
+        expected_big = np.array([0, 7, 1, 7], dtype=np.uint8)
+        expected_little = np.array([7, 0, 7, 1], dtype=np.uint8)
+        unpacked_big = np.empty_like(expected_big)
+        unpacked_little = np.empty_like(expected_little)
+        kernels.unpack4_8_big(input_arr, unpacked_big)
+        kernels.unpack4_8_little(input_arr, unpacked_little)
+        np.testing.assert_array_equal(unpacked_big, expected_big, strict=True)
+        np.testing.assert_array_equal(unpacked_little, expected_little, strict=True)
 
-    def test_pack4_8(self):
-        input_arr = np.arange(255, dtype=np.uint8)
-        unpacked = np.empty(input_arr.size * 2, dtype="ubyte")
-        output = kernels.pack4_8(kernels.unpack4_8(input_arr, unpacked))
-        np.testing.assert_array_equal(input_arr, output)
+    @pytest.mark.parametrize("nbits", [1, 2, 4])
+    @pytest.mark.parametrize("bitorder", ["big", "little"])
+    @pytest.mark.parametrize("parallel", [False, True])
+    def test_pack(self, nbits: int, bitorder: str, parallel: bool) -> None:  # noqa: FBT001
+        rng = np.random.default_rng()
+        arr = rng.integers(255, size=2**10, dtype=np.uint8)
+        parallel_str = "" if parallel else "_serial"
+        unpack_func = getattr(kernels, f"unpack{nbits:d}_8_{bitorder}{parallel_str}")
+        pack_func = getattr(kernels, f"pack{nbits:d}_8_{bitorder}{parallel_str}")
+        unpacked = np.zeros(arr.size * 8 // nbits, dtype=np.uint8)
+        unpack_func(arr, unpacked)
+        packed = np.empty_like(arr)
+        pack_func(unpacked, packed)
+        np.testing.assert_array_equal(packed, arr, strict=True)
 
 
 class TestStats(object):
diff --git a/tests/test_sigproc.py b/tests/test_sigproc.py
index 577e7aa..2227bd6 100644
--- a/tests/test_sigproc.py
+++ b/tests/test_sigproc.py
@@ -1,7 +1,8 @@
-import pytest
-import struct
 import shutil
+import struct
+
 import numpy as np
+import pytest
 from astropy.coordinates import SkyCoord
 
 from sigpyproc.io import sigproc
@@ -113,10 +114,10 @@ def test_stream_info_add_entry(self):
         assert len(stream_info.entries) == 1
         assert stream_info.entries[0] == file_info
 
-    def test_stream_info_add_entry_invalid(self):
+    def test_stream_info_add_entry_invalid(self) -> None:
         stream_info = sigproc.StreamInfo()
-        with pytest.raises(ValueError):
-            stream_info.add_entry("invalid")
+        with pytest.raises(TypeError):
+            stream_info.add_entry("invalid")  # type: ignore [arg-type]
 
     def test_stream_info_check_contiguity_valid(self):
         file_info1 = sigproc.FileInfo(