Skip to content

Commit

Permalink
feat: exposing wrapper for conventional iterative recon
Browse files Browse the repository at this point in the history
  • Loading branch information
mcencini committed Mar 11, 2024
1 parent 7d303d3 commit 71f00a4
Show file tree
Hide file tree
Showing 10 changed files with 360 additions and 13 deletions.
2 changes: 1 addition & 1 deletion src/deepmr/_signal/resize.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ def resample(input, oshape, filt=True, polysmooth=False):

# if required, apply filtering
if filt is not None:
freq *= filt
freq *= filt.to(freq.device)

# transform back
output = _ifftc(freq, axes)
Expand Down
6 changes: 2 additions & 4 deletions src/deepmr/linops/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,10 @@
from .fft import * # noqa
from .nufft import * # noqa

__all__ = ["EncodingOp"]
__all__ = []
__all__.extend(_base.__all__)
__all__.extend(_coil.__all__)
__all__.extend(_fft.__all__)
__all__.extend(_nufft.__all__)


def EncodingOp():
pass

31 changes: 31 additions & 0 deletions src/deepmr/linops/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,10 @@ def prox_l2(self, z, y, gamma):
H = lambda x: self.A(x) + 1 / gamma * x
x = conjugate_gradient(H, b, self.max_iter, self.tol)
return x

def maxeig(self, input, max_iter=10, tol=1e-6):
    """
    Estimate the maximum eigenvalue (spectral norm) of this operator.

    Parameters
    ----------
    input : torch.Tensor
        Template tensor; only its shape, dtype and device are used to
        draw the random starting vector for the power iteration.
    max_iter : int, optional
        Maximum number of power iterations. The default is ``10``.
    tol : float, optional
        Stopping tolerance. The default is ``1e-6``.

    Returns
    -------
    torch.Tensor
        Estimated spectral norm of ``self.A``.
    """
    # draw the random probe on the same device and with the same dtype as
    # the input; a plain torch.randn(input.shape) would live on CPU with a
    # real dtype and break for GPU and/or complex inputs
    x = torch.randn(input.shape, dtype=input.dtype, device=input.device)
    return power_iter(self.A, x, max_iter, tol)


# %% local utils
Expand Down Expand Up @@ -170,3 +174,30 @@ def A(x):

return x

@torch.no_grad()
def power_iter(A, x0, max_iter=2, tol=1e-6):
    r"""
    Use power iteration to estimate the spectral norm of a linear operator.

    Adapted from MIRTorch
    (https://github.com/guanhuaw/MIRTorch/blob/master/mirtorch/alg/spectral.py)

    Args:
        A: a callable implementing the linear operator, ``x -> A(x)``
        x0: initial guess of the singular vector corresponding to the
            maximum singular value
        max_iter: maximum number of iterations
        tol: stopping tolerance on the change of the estimate between
            two consecutive iterations

    Returns:
        The estimated spectral norm (largest singular value) of ``A``.
    """
    x = x0
    max_eig = float("inf")
    for _ in range(max_iter):
        Ax = A(x)
        new_eig = torch.norm(Ax)
        # normalize the *mapped* vector: the previous version divided the old
        # x by ||Ax||, which makes the estimate collapse to 1 regardless of
        # the true spectral norm
        x = Ax / new_eig
        # early stop once the estimate has stabilized within tol
        if abs(new_eig - max_eig) < tol * abs(new_eig):
            max_eig = new_eig
            break
        max_eig = new_eig

    return max_eig

7 changes: 7 additions & 0 deletions src/deepmr/optim/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,13 @@

# from .admm import * # noqa
from .data_fidelity import * # noqa
from deepinv.optim.optim_iterators import OptimIterator # noqa
from deepinv.optim.optim_iterators import GDIteration # noqa
from deepinv.optim.optim_iterators import PGDIteration # noqa
from deepinv.optim.optim_iterators import CPIteration # noqa
from deepinv.optim.optim_iterators import DRSIteration # noqa
from deepinv.optim.optim_iterators import HQSIteration # noqa

__all__ = []
__all__.extend(_data_fidelity.__all__)
# NOTE: "OptimIterator, GDIteration" was previously one malformed string,
# so neither name was actually exported; split into separate entries.
__all__.extend(
    [
        "OptimIterator",
        "GDIteration",
        "PGDIteration",
        "CPIteration",
        "DRSIteration",
        "HQSIteration",
    ]
)
14 changes: 7 additions & 7 deletions src/deepmr/prox/tv.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from deepinv.optim.prior import PnP


def TVPrior(ndim, device=None, verbose=False, n_it_max=1000, crit=1e-5, x2=None, u2=None):
def TVPrior(ndim, device=None, verbose=False, niter=100, crit=1e-5, x2=None, u2=None):
r"""
Proximal operator of the isotropic Total Variation operator.
Expand Down Expand Up @@ -39,7 +39,7 @@ def TVPrior(ndim, device=None, verbose=False, n_it_max=1000, crit=1e-5, x2=None,
Device on which the wavelet transform is computed. Default is ``None``.
verbose : bool, optional
Whether to print computation details or not. Default: ``False``.
n_it_max : int, optional,
niter : int, optional,
Maximum number of iterations. Default: ``1000``.
crit : float, optional
Convergence criterion. Default: 1e-5.
Expand All @@ -55,10 +55,10 @@ def TVPrior(ndim, device=None, verbose=False, n_it_max=1000, crit=1e-5, x2=None,
variation image denoising and deblurring problems", IEEE T. on Image Processing. 18(11), 2419-2434, 2009.
"""
return PnP(denoiser=ComplexTVDenoiser(ndim, device, verbose, n_it_max, crit, x2, u2))
return PnP(denoiser=ComplexTVDenoiser(ndim, device, verbose, niter, crit, x2, u2))


def tv_denoise(input, ndim, ths=0.1, device=None, verbose=False, n_it_max=1000, crit=1e-5, x2=None, u2=None):
def tv_denoise(input, ndim, ths=0.1, device=None, verbose=False, niter=100, crit=1e-5, x2=None, u2=None):
r"""
Apply isotropic Total Variation denoising.
Expand Down Expand Up @@ -93,7 +93,7 @@ def tv_denoise(input, ndim, ths=0.1, device=None, verbose=False, n_it_max=1000,
Device on which the wavelet transform is computed. Default is ``None``.
verbose : bool, optional
Whether to print computation details or not. Default: ``False``.
n_it_max : int, optional,
niter : int, optional,
Maximum number of iterations. Default: ``1000``.
crit : float, optional
Convergence criterion. Default: 1e-5.
Expand All @@ -114,7 +114,7 @@ def tv_denoise(input, ndim, ths=0.1, device=None, verbose=False, n_it_max=1000,
Denoised image of shape (..., n_ndim, ..., n_0).
"""
TV = ComplexTVDenoiser(ndim, device, verbose, n_it_max, crit, x2, u2)
TV = ComplexTVDenoiser(ndim, device, verbose, niter, crit, x2, u2)
return TV(input, ths)


Expand Down Expand Up @@ -144,7 +144,7 @@ def forward(self, input, ths):
if self.denoiser.device is None:
device = idevice
else:
self.denoiser.device = device
device = self.denoiser.device

# get input shape
ndim = self.denoiser.ndim
Expand Down
3 changes: 2 additions & 1 deletion src/deepmr/prox/wavelet.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,7 @@ class ComplexWaveletDenoiser(torch.nn.Module):
def __init__(self, ndim, wv, device, p, level, *args, **kwargs):
    """
    Wrap deepinv's wavelet prior so it can handle complex-valued inputs.

    Parameters
    ----------
    ndim : int
        Number of spatial dimensions of the wavelet transform.
    wv : str
        Wavelet name, forwarded to the underlying denoiser.
    device : str | torch.device | None
        Computational device; ``None`` means "infer from the input later".
    p : float
        Norm order for the wavelet prior.
    level : int
        Number of decomposition levels.
    *args, **kwargs
        Extra arguments forwarded verbatim to the underlying denoiser.
    """
    super().__init__()
    # build the underlying deepinv wavelet prior; all parameters forwarded
    self.denoiser = _WaveletPrior(level=level, wv=wv, p=p, device=device, wvdim=ndim, *args, **kwargs)
    # store the requested device on the denoiser so forward() can check
    # whether a device was explicitly provided (presumably _WaveletPrior
    # does not keep it itself — TODO confirm)
    self.denoiser.device = device

def forward(self, input, ths):

Expand All @@ -121,7 +122,7 @@ def forward(self, input, ths):
if self.denoiser.device is None:
device = idevice
else:
self.denoiser.device = device
device = self.denoiser.device

# get input shape
ndim = self.denoiser.wvdim
Expand Down
4 changes: 4 additions & 0 deletions src/deepmr/recon/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,12 @@
"""

from . import calib as _calib
from . import alg as _alg

from .calib import * # noqa
from .alg import * # noqa

__all__ = []
__all__.extend(_calib.__all__)
__all__.extend(_alg.__all__)

11 changes: 11 additions & 0 deletions src/deepmr/recon/alg/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
"""Sub-package containing image reconstruction wrapper routines."""

from . import linop as _linop
from . import classic_recon as _classic_recon

from .linop import * # noqa
from .classic_recon import * # noqa

__all__ = []
__all__.extend(_linop.__all__)
__all__.extend(_classic_recon.__all__)
188 changes: 188 additions & 0 deletions src/deepmr/recon/alg/classic_recon.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,188 @@
"""Classical iterative reconstruction wrapper."""

__all__ = ["recon_lstsq"]

import copy

import numpy as np
import torch

import deepinv as dinv

from ... import optim as _optim
from ... import prox as _prox
from .. import calib as _calib
from . import linop as _linop

from numba.core.errors import NumbaPerformanceWarning
import warnings

warnings.simplefilter('ignore', category=NumbaPerformanceWarning)

def recon_lstsq(data, head, mask=None, niter=1, prior=None, prior_ths=0.01, prior_params=None, lamda=0.0, stepsize=None, basis=None, nsets=1, device=None, cal_data=None, toeplitz=True):
    """
    Classical MR reconstruction.

    Parameters
    ----------
    data : np.ndarray | torch.Tensor
        Input k-space data of shape ``(nslices, ncoils, ncontrasts, nviews, nsamples)``.
    head : deepmr.Header
        DeepMR acquisition header, containing ``traj``, ``shape`` and ``dcf``.
    mask : np.ndarray | torch.Tensor, optional
        Sampling mask for Cartesian imaging.
        Expected shape is ``(ncontrasts, nviews, nsamples)``.
        The default is ``None``.
    niter : int, optional
        Number of recon iterations. If single iteration,
        perform simple zero-filled recon. The default is ``1``.
    prior : str | deepinv.optim.Prior, optional
        Prior for image regularization. If string, it must be one of the following:

        * ``"L1Wave"``: L1 Wavelet regularization.
        * ``"TV"``: Total Variation regularization.

        The default is ``None`` (no regularizer).
    prior_ths : float, optional
        Threshold for denoising in regularizer. The default is ``0.01``.
    prior_params : dict, optional
        Parameters for Prior initializations.
        See :func:`deepmr.prox`.
        The default is ``None`` (use each regularizer default parameters).
    lamda : float, optional
        Tikhonov regularization strength. If 0.0, do not apply
        Tikhonov regularization. The default is ``0.0``.
    stepsize : float, optional
        Iterations step size. If not provided, estimate from Encoding
        operator maximum eigenvalue. The default is ``None``.
    basis : np.ndarray | torch.Tensor, optional
        Low rank subspace basis of shape ``(ncontrasts, ncoeffs)``. The default is ``None``.
    nsets : int, optional
        Number of coil sensitivity sets of maps. The default is ``1``.
    device : str, optional
        Computational device. The default is ``None`` (same as ``data``).
    cal_data : np.ndarray | torch.Tensor, optional
        Calibration dataset for coil sensitivity estimation.
        The default is ``None`` (use center region of ``data``).
    toeplitz : bool, optional
        Use Toeplitz approach for normal equation. The default is ``True``.

    Returns
    -------
    img : np.ndarray | torch.Tensor
        Reconstructed image of shape:

        * 2D Cartesian: ``(nslices, ncontrasts, ny, nx)``.
        * 2D Non Cartesian: ``(nslices, ncontrasts, ny, nx)``.
        * 3D Cartesian: ``(nslices, ncontrasts, ny, nx)``.
        * 3D Non Cartesian: ``(ncontrasts, nz, ny, nx)``.

    Raises
    ------
    NotImplementedError
        If ``prior`` is a list or tuple (multiple priors not supported yet).
    """
    # remember input container type so we can return the same kind
    if isinstance(data, np.ndarray):
        data = torch.as_tensor(data)
        isnumpy = True
    else:
        isnumpy = False

    if device is None:
        device = data.device
    data = data.to(device)

    # density compensation may be absent (e.g., Cartesian acquisitions)
    if head.dcf is not None:
        head.dcf = head.dcf.to(device)

    # toggle off toeplitz for non-iterative recon (no normal equation needed)
    if niter == 1:
        toeplitz = False

    # get number of spatial encoding dimensions
    if head.traj is not None:
        ndim = head.traj.shape[-1]
    else:
        ndim = 2  # assume 3D data already decoupled along readout

    # build encoding operator (forward E and normal operator EHE)
    E, EHE = _linop.EncodingOp(data, mask, head.traj, head.dcf, head.shape, nsets, basis, device, cal_data, toeplitz)

    # perform zero-filled reconstruction; only pre-weight by sqrt(dcf)
    # when a density compensation function actually exists, otherwise
    # `head.dcf**0.5` would raise on None
    if head.dcf is not None:
        img = E.H(head.dcf**0.5 * data[:, None, ...])
    else:
        img = E.H(data[:, None, ...])

    # if non-iterative, just perform linear recon
    if niter == 1:
        output = img
        if isnumpy:
            output = output.numpy(force=True)
        return output

    # rescale so that regularization thresholds are intensity-independent
    img = _calib.intensity_scaling(img, ndim=ndim)

    # if no prior is specified, use CG recon
    if prior is None:
        output = EHE.solve(img, max_iter=niter, lamda=lamda)
        if isnumpy:
            output = output.numpy(force=True)
        return output

    # multiple priors are not supported yet; fail loudly instead of
    # silently falling through and returning None
    if isinstance(prior, (list, tuple)):
        raise NotImplementedError("Multiple priors are not supported.")

    # single prior: solve with proximal gradient descent (PGD)
    # default prior params
    if prior_params is None:
        prior_params = {}

    # fold Tikhonov term into the normal operator: EHE' = EHE + lamda * I
    if lamda != 0.0:
        img = img / lamda
        prior_ths = prior_ths / lamda
        tmp = copy.deepcopy(EHE)
        f = lambda x: tmp.A(x) + lamda * x
        EHE.A = f
        EHE.A_adjoint = f
    else:
        lamda = 1.0

    # estimate stepsize from the spectral norm of the normal operator
    if stepsize is None:
        max_eig = EHE.maxeig(img, max_iter=1)
        if max_eig == 0.0:
            stepsize = 1.0
        else:
            stepsize = 1.0 / float(max_eig)

    # solver parameters
    params_algo = {"stepsize": stepsize, "g_param": prior_ths, "lambda": lamda}

    # select the data fidelity term
    data_fidelity = _optim.L2()

    # build the regularizer (wavelet / TV)
    prior = _get_prior(prior, ndim, device, **prior_params)

    # instantiate the algorithm class to solve the inverse problem
    solver = dinv.optim.optim_builder(
        iteration="PGD",
        prior=prior,
        data_fidelity=data_fidelity,
        early_stop=True,
        max_iter=niter,
        params_algo=params_algo,
    )

    # undo the 1/lamda scaling applied above
    output = solver(img, EHE) * lamda
    if isnumpy:
        output = output.numpy(force=True)
    return output


# %% local utils
def _get_prior(ptype, ndim, device, **params):
    """
    Map a prior name to its concrete regularizer instance.

    ``ptype`` must be the string ``"L1Wave"`` or ``"TV"``; any other string
    raises ``ValueError`` and any non-string raises ``NotImplementedError``.
    """
    # only string identifiers are supported for now
    if not isinstance(ptype, str):
        raise NotImplementedError("Direct prior object not implemented.")

    # dispatch on the prior name
    if ptype == "L1Wave":
        return _prox.WaveletPrior(ndim, device=device, **params)
    if ptype == "TV":
        return _prox.TVPrior(ndim, device=device, **params)

    raise ValueError(f"Prior type = {ptype} not recognized; either specify 'L1Wave', 'TV' or 'deepinv.optim.Prior' object.")
Loading

0 comments on commit 71f00a4

Please sign in to comment.