# signal_base.py
"""
Defines the signal base classes and metaclasses. All signals will then be
derived from these base classes.
"""
import collections
try:
from collections.abc import Sequence
except:
from collections import Sequence
import itertools
import logging
import numpy as np
import scipy.linalg as sl
import scipy.sparse as sps
import six
# these are defined in parameter.py, but currently imported
# in various places from signal_base.py
from enterprise.signals.parameter import Function # noqa: F401
from enterprise.signals.parameter import function # noqa: F401
from enterprise.signals.parameter import ConstantParameter
from enterprise.signals.utils import KernelMatrix
# logging.basicConfig(format="%(levelname)s: %(name)s: %(message)s", level=logging.INFO)
logger = logging.getLogger(__name__)

try:
    # Preferred backend: CHOLMOD-based sparse Cholesky factorization.
    from sksparse.cholmod import cholesky
except ImportError:
    msg = "No sksparse library. Using scipy instead!"
    logger.warning(msg)

    class cholesky(object):
        """Dense scipy stand-in used when ``sksparse`` cannot be imported.

        Instances expose the operations this module relies on: calling the
        object solves a linear system, ``logdet()`` returns the
        log-determinant and ``inv()`` returns the explicit inverse.
        """

        def __init__(self, x):
            # scipy's dense routines cannot consume sparse matrices directly
            dense = x.toarray() if sps.issparse(x) else x
            self.cf = sl.cho_factor(dense)

        def __call__(self, other):
            """Solve ``x @ result = other`` using the stored factorization."""
            return sl.cho_solve(self.cf, other)

        def logdet(self):
            """Return the log-determinant of the factored matrix."""
            factor_diagonal = np.diag(self.cf[0])
            return np.sum(2 * np.log(factor_diagonal))

        def inv(self):
            """Return the inverse of the factored matrix as a dense array."""
            identity = np.eye(len(self.cf[0]))
            return sl.cho_solve(self.cf, identity)
@six.add_metaclass(MetaSignal)
class Signal(object):
    """Base class for Signal objects.

    The getters below are neutral defaults (``None`` or ``0``); the method
    names suggest concrete signals override the ones they contribute to.
    """

    def __init__(self, psr):
        self.psrname = psr.name

    @property
    def params(self):
        """List of this signal's parameters, excluding constant ones."""
        varying = []
        for par in self._params.values():
            if isinstance(par, ConstantParameter):
                continue
            varying.append(par)
        return varying

    @property
    def param_names(self):
        """Flat list of parameter names.

        A parameter with a truthy ``size`` contributes one name per element,
        suffixed ``_0``, ``_1``, ...; any other parameter contributes its
        plain name.
        """
        names = []
        for par in self.params:
            if not par.size:
                names.append(par.name)
                continue
            names.extend(par.name + "_{}".format(ii) for ii in range(0, par.size))
        return names

    def get(self, parname, params={}):
        """Return the value of parameter ``parname``.

        The value is looked up in ``params`` under the parameter's full name;
        if it is absent there, the parameter's own ``value`` is returned.
        """
        try:
            found = params[self._params[parname].name]
        except KeyError:
            found = self._params[parname].value
        return found

    def set_default_params(self, params):
        """Set default parameters.

        Only constant parameters are touched: those named in ``params`` get
        the supplied value, and those missing from ``params`` that still have
        no value trigger a warning.
        """
        for kw, par in self._params.items():
            supplied = par.name in params
            if not isinstance(par, ConstantParameter):
                continue
            if supplied:
                msg = "Setting {} to {}".format(par.name, params[par.name])
                logger.info(msg)
                self._params[kw].value = params[par.name]
            elif par.value is None:
                msg = "{} not set! Check parameter dict.".format(par.name)
                logger.warning(msg)

    def get_ndiag(self, params):
        """Returns the diagonal of the white noise vector `N`.

        This method also supports block diagonal sparse matrices.
        """
        return None

    def get_delay(self, params):
        """Returns the waveform of a deterministic signal."""
        return 0

    def get_basis(self, params=None):
        """Returns the basis array of shape N_toa x N_basis."""
        return None

    def get_phi(self, params):
        """Returns a diagonal covariance matrix of the basis amplitudes."""
        return None

    def get_phiinv(self, params):
        """Returns the inverse of the covariance of the basis amplitudes."""
        return None

    def get_logsignalprior(self, params):
        """Returns any additional prior/likelihood terms associated with a signal."""
        return 0
class CommonSignal(Signal):
    """Base class for CommonSignal objects."""

    def get_phiinv(self, params):
        """Always raise: the per-signal inverse is not provided for common signals."""
        raise RuntimeError("You probably shouldn't be calling get_phiinv() " "on a common red-noise signal.")

    @classmethod
    def get_phicross(cls, signal1, signal2, params):
        """Return the cross-covariance between two instances (``None`` by default)."""
        return None
class LogLikelihood(object):
    """Callable that evaluates the log-likelihood of a ``PTA``-like object.

    The object passed as ``pta`` must provide ``map_params``, ``get_TNr``,
    ``get_TNT``, ``get_phiinv``, ``get_rNr_logdet``, ``get_logsignalprior``
    and the ``_commonsignals`` attribute, all of which are used below.
    """

    def __init__(self, pta):
        self.pta = pta

    def _make_sigma(self, TNTs, phiinv):
        """Return the sparse (CSC) sum of block-diagonal ``TNTs`` and ``phiinv``."""
        return sps.block_diag(TNTs, "csc") + sps.csc_matrix(phiinv)

    def __call__(self, xs, phiinv_method="cliques"):
        """Evaluate the log-likelihood.

        :param xs: parameter dictionary, or a flat vector that is converted
            with ``pta.map_params``
        :param phiinv_method: forwarded to ``pta.get_phiinv`` as ``method``
        :return: the log-likelihood, or ``-np.inf`` when the Cholesky
            factorization of ``Sigma`` fails
        """
        # map parameter vector if needed
        params = xs if isinstance(xs, dict) else self.pta.map_params(xs)

        loglike = 0

        # phiinvs will be a list or may be a big matrix if spatially
        # correlated signals
        TNrs = self.pta.get_TNr(params)
        TNTs = self.pta.get_TNT(params)
        phiinvs = self.pta.get_phiinv(params, logdet=True, method=phiinv_method)

        # get -0.5 * (rNr + logdet_N) piece of likelihood
        # the np.sum here is needed because each pulsar returns a 2-tuple
        loglike += -0.5 * np.sum([ell for ell in self.pta.get_rNr_logdet(params)])

        # get extra prior/likelihoods
        loglike += sum(self.pta.get_logsignalprior(params))

        # red noise piece
        if self.pta._commonsignals:
            phiinv, logdet_phi = phiinvs

            Sigma = self._make_sigma(TNTs, phiinv)
            TNr = np.concatenate(TNrs)

            try:
                cf = cholesky(Sigma)
                expval = cf(TNr)
            except Exception:
                # A failed factorization marks the point as invalid. The
                # backend behind ``cholesky`` is chosen at import time, so its
                # error type is not known here; catching ``Exception`` (rather
                # than using a bare ``except``) still lets KeyboardInterrupt
                # and SystemExit propagate.
                return -np.inf

            logdet_sigma = cf.logdet()

            loglike += 0.5 * (np.dot(TNr, expval) - logdet_sigma - logdet_phi)
        else:
            for TNr, TNT, pl in zip(TNrs, TNTs, phiinvs):
                if TNr is None:
                    continue

                phiinv, logdet_phi = pl
                Sigma = TNT + (np.diag(phiinv) if phiinv.ndim == 1 else phiinv)

                try:
                    cf = sl.cho_factor(Sigma)
                    expval = sl.cho_solve(cf, TNr)
                except (np.linalg.LinAlgError, ValueError):
                    # LinAlgError: Sigma is not positive definite;
                    # ValueError: Sigma contains NaN/inf entries.
                    return -np.inf

                logdet_sigma = np.sum(2 * np.log(np.diag(cf[0])))

                loglike += 0.5 * (np.dot(TNr, expval) - logdet_sigma - logdet_phi)

        return loglike
class PTA(object):
    """Collection of per-pulsar signal collections forming one model.

    :param init: a single signal collection or a sequence of them
    :param lnlikelihood: class (or factory) called with this ``PTA`` to build
        the log-likelihood object on first use
    """

    def __init__(self, init, lnlikelihood=LogLikelihood):
        if isinstance(init, Sequence):
            self._signalcollections = list(init)
        else:
            self._signalcollections = [init]

        self.lnlikelihood = lnlikelihood

        # set signal dictionary
        self._set_signal_dict()

    def __add__(self, other):
        """Return a new ``PTA`` extended by another ``PTA`` or a signal collection."""
        if hasattr(other, "_signalcollections"):
            return PTA(self._signalcollections + other._signalcollections, lnlikelihood=self.lnlikelihood)
        else:
            return PTA(self._signalcollections + [other], lnlikelihood=self.lnlikelihood)

    @property
    def params(self):
        """Unique sub-parameters of all signal collections, sorted by name."""
        unique = {
            par
            for signalcollection in self._signalcollections
            for param in signalcollection.params
            for par in param.params
        }
        return sorted(unique, key=lambda par: par.name)

    @property
    def param_names(self):
        """Flat list of parameter names; sized parameters expand to ``name_i``."""
        ret = []
        for p in self.params:
            if p.size:
                for ii in range(0, p.size):
                    ret.append(p.name + "_{}".format(ii))
            else:
                ret.append(p.name)
        return ret

    @property
    def pulsarmodels(self):
        """The list of signal collections held by this object."""
        return self._signalcollections

    def get_TNr(self, params):
        """Return ``get_TNr(params)`` of every signal collection, as a list."""
        return [signalcollection.get_TNr(params) for signalcollection in self._signalcollections]

    def get_TNT(self, params):
        """Return ``get_TNT(params)`` of every signal collection, as a list."""
        return [signalcollection.get_TNT(params) for signalcollection in self._signalcollections]

    def get_rNr_logdet(self, params):
        """Return ``get_rNr_logdet(params)`` of every signal collection, as a list."""
        return [signalcollection.get_rNr_logdet(params) for signalcollection in self._signalcollections]

    def get_residuals(self):
        """Return the stored residuals of every signal collection, as a list."""
        return [signalcollection._residuals for signalcollection in self._signalcollections]

    def get_ndiag(self, params={}):
        """Return ``get_ndiag(params)`` of every signal collection, as a list."""
        return [signalcollection.get_ndiag(params) for signalcollection in self._signalcollections]

    def get_delay(self, params={}):
        """Return ``get_delay(params)`` of every signal collection, as a list."""
        return [signalcollection.get_delay(params) for signalcollection in self._signalcollections]

    def get_logsignalprior(self, params):
        """Return ``get_logsignalprior(params)`` of every signal collection, as a list."""
        return [signalcollection.get_logsignalprior(params) for signalcollection in self._signalcollections]

    def set_default_params(self, params):
        """Forward ``params`` to ``set_default_params`` of every signal collection."""
        for sc in self._signalcollections:
            sc.set_default_params(params)

    def get_basis(self, params={}):
        """Return ``get_basis(params)`` of every signal collection, as a list."""
        return [signalcollection.get_basis(params) for signalcollection in self._signalcollections]

    @property
    def _lnlikelihood(self):
        # instantiate on first use
        if not hasattr(self, "_lnlike"):
            self._lnlike = self.lnlikelihood(self)
        return self._lnlike

    def get_lnlikelihood(self, params, **kwargs):
        """Evaluate the log-likelihood object; ``kwargs`` are passed through."""
        return self._lnlikelihood(params, **kwargs)

    @property
    def _commonsignals(self):
        """Dictionary ``{signal class: OrderedDict(signal -> signal collection)}``.

        Only ``CommonSignal`` instances without ``_coefficients`` are kept,
        and only classes found more than once.
        """
        # cache the computation if we don't have it yet
        if not hasattr(self, "_cs"):
            commonsignals = collections.defaultdict(collections.OrderedDict)

            for signalcollection in self._signalcollections:
                # TODO: need a better signal that a
                # signalcollection provides a basis
                if signalcollection._Fmat is not None:
                    for signal in signalcollection._signals:
                        # if the CommonSignal is coefficient based we don't
                        # need to worry about it for get_phi and get_phiinv
                        if isinstance(signal, CommonSignal) and not getattr(signal, "_coefficients", {}):
                            commonsignals[signal.__class__][signal] = signalcollection

            # drop common signals that appear only once
            self._cs = {csclass: csdict for csclass, csdict in commonsignals.items() if len(csdict) > 1}

        return self._cs

    # return a dictionary (indexed by SignalCollection) of Python slices
    # corresponding to the span of each pulsar within a Phi matrix
    def _get_slices(self, phivecs):
        ret, offset = {}, 0
        for sc, phivec in zip(self._signalcollections, phivecs):
            # assume phi is either a column vector or a square matrix
            stop = 0 if phivec is None else phivec.shape[0]

            ret[sc] = slice(offset, offset + stop)
            offset = ret[sc].stop

        return ret

    def get_phiinv(self, params, logdet=False, method="cliques"):
        """Return the inverse of Phi computed with the selected ``method``.

        :param method: one of ``"cliques"``, ``"partition"`` or ``"sparse"``
        :raises NotImplementedError: for any other ``method``
        """
        if method == "cliques":
            return self.get_phiinv_byfreq_cliques(params, logdet)
        elif method == "partition":
            return self.get_phiinv_byfreq_partition(params, logdet)
        elif method == "sparse":
            return self.get_phiinv_sparse(params, logdet)
        else:
            raise NotImplementedError("Unknown phiinv method: {!r}".format(method))

    def get_phiinv_sparse(self, params, logdet=False):
        """Invert Phi through a Cholesky factorization of its sparse (CSC) form.

        When ``get_phi`` returns a list (no common signals), each entry is
        inverted on its own with its ``inv`` method instead.
        """
        phi = self.get_phi(params)

        if isinstance(phi, list):
            return [None if phivec is None else phivec.inv(logdet) for phivec in phi]
        else:
            phisparse = sps.csc_matrix(phi)
            cf = cholesky(phisparse)

            if logdet:
                return (cf.inv(), cf.logdet())
            else:
                return cf.inv()

    def get_phiinv_byfreq_partition(self, params, logdet=False):
        """Invert Phi by inverting one small matrix per common-signal basis index.

        :raises NotImplementedError: when a common-signal cross term is 2-d
        """
        phivecs = [signalcollection.get_phi(params) for signalcollection in self._signalcollections]

        # if we found common signals, we'll return a big phivec matrix,
        # otherwise a list of phivec vectors (some of which possibly None)
        if self._commonsignals:
            slices = self._get_slices(phivecs)

            # TODO: This is messy, maybe we should clean up
            phis = [phivec for phivec in phivecs if phivec is not None]
            if np.any([phivec.ndim == 2 for phivec in phis]):
                # Always request the (inverse, logdet) pair: the [0]/[1]
                # indexing below is only valid on that tuple form, so passing
                # logdet=False here would index into the inverse itself.
                phiinvs = [phivec.inv(True) for phivec in phis]
                phiinv_full = [np.diag(phi[0]) if phi[0].ndim == 1 else phi[0] for phi in phiinvs]
                phiinv = sl.block_diag(*phiinv_full)
                if logdet:
                    ld = np.sum([pi[1] for pi in phiinvs])
                phidiag = np.concatenate([np.diag(phi) if phi.ndim == 2 else phi for phi in phis])
            else:
                phidiag = np.concatenate(phis)
                phiinv = np.diag(1.0 / phidiag)
                if logdet:
                    ld = np.sum(np.log(phidiag))

            # this will only work if all common signals are shared among all
            # the pulsars and share the same basis
            invert = None
            for csclass, csdict in self._commonsignals.items():
                for i, (cs1, csc1) in enumerate(csdict.items()):
                    for j, (cs2, csc2) in enumerate(csdict.items()):
                        if j <= i:
                            continue

                        # hoping they're all the same...
                        crossdiag = csclass.get_phicross(cs1, cs2, params)

                        if invert is None:
                            invert = np.zeros((len(crossdiag), len(csdict), len(csdict)), "d")

                        if crossdiag.ndim == 2:
                            raise NotImplementedError(
                                "get_phiinv with method='partition' does not " "support dense phi matrices."
                            )
                        invert[:, i, j] += crossdiag
                        invert[:, j, i] += crossdiag

                    invert[:, i, i] += phidiag[slices[csc1]][csc1._idx[cs1]]

                    if logdet:
                        # these diagonal entries are re-counted through the
                        # per-index Cholesky factors below
                        ld -= np.sum(np.log(phidiag[slices[csc1]][csc1._idx[cs1]]))

            for k in range(len(crossdiag)):
                cf = sl.cho_factor(invert[k, :, :])
                invert[k, :, :] = sl.cho_solve(cf, np.eye(invert[k, :, :].shape[0]))

                if logdet:
                    ld += np.sum(2 * np.log(np.diag(cf[0])))

            csdict = list(self._commonsignals.values())[0]
            for i, (cs1, csc1) in enumerate(csdict.items()):
                block1, idx1 = slices[csc1], csc1._idx[cs1]

                for j, (cs2, csc2) in enumerate(csdict.items()):
                    if j < i:
                        continue

                    block2, idx2 = slices[csc2], csc2._idx[cs2]

                    # basic slicing yields a view, so the indexed assignment
                    # writes through to phiinv
                    phiinv[block1, block2][idx1, idx2] = invert[:, i, j]
                    phiinv[block2, block1][idx2, idx1] = invert[:, i, j]

            if logdet:
                return phiinv, ld
            else:
                return phiinv
        else:
            return [None if phivec is None else phivec.inv(logdet) for phivec in phivecs]

    def get_phiinv_byfreq_cliques(self, params, logdet=False, cholesky=False):
        """Invert Phi clique by clique.

        :param cholesky: when true, invert each clique with a Cholesky
            factorization instead of ``np.linalg.inv``/``slogdet``
        :return: the inverse (written in place into the array returned by
            ``get_phi``), paired with the log-determinant when ``logdet``
        """
        phi = self.get_phi(params, cliques=True)

        if isinstance(phi, list):
            return [None if phivec is None else phivec.inv(logdet) for phivec in phi]
        else:
            ld = 0

            # first invert all the cliques
            for clcount in range(self._clcount):
                idx = self._cliques == clcount

                # a clique id may be empty after merging in _setcliques
                if np.any(idx):
                    idx2 = np.ix_(idx, idx)

                    if cholesky:
                        cf = sl.cho_factor(phi[idx2])

                        if logdet:
                            ld += 2.0 * np.sum(np.log(np.diag(cf[0])))

                        phi[idx2] = sl.cho_solve(cf, np.identity(cf[0].shape[0]))
                    else:
                        phi2 = phi[idx2]

                        if logdet:
                            ld += np.linalg.slogdet(phi2)[1]

                        phi[idx2] = np.linalg.inv(phi2)

            # then do the pure diagonal terms
            idx = self._cliques == -1

            if logdet:
                ld += np.sum(np.log(phi[idx, idx]))

            phi[idx, idx] = 1.0 / phi[idx, idx]

            return (phi, ld) if logdet else phi

    # we use "cliques" to account for sparse non-diagonal Phi matrices
    # for each value in self._cliques, the matrix indices with that value form
    # an independent submatrix that can be inverted separately

    # reset clique index
    def _resetcliques(self, n):
        self._cliques = -1 * np.ones(n)
        self._clcount = 0

    # update clique index by considering a common signal under
    # the assumption that the corresponding "big-Phi" matrix is block diagonal
    def _setcliques(self, slices, csdict):
        # each column in idxmatrix (mind the .T) corresponds to the indices
        # that participate in a common signal for a given pulsar
        idxmatrix = np.array([csc._idx[cs] for cs, csc in csdict.items()]).T

        # each row in the updated idxmatrix corresponds to a set of "global"
        # Phi indices that are correlated across pulsars
        idxmatrix = idxmatrix + np.array([slices[csc].start for cs, csc in csdict.items()])

        # loop over vectors of common-signal-correlated global-indices
        for idxs in idxmatrix:
            # find the existing cliques assigned to these global indices
            allidx = set(self._cliques[idxs])
            maxidx = max(allidx)

            if maxidx == -1:
                # no clique found (every entry is -1, since -1 is the smallest
                # value _cliques can hold): create a new clique for idxs
                self._cliques[idxs] = self._clcount
                self._clcount = self._clcount + 1
            else:
                # if we find at least one clique, assign all indices in idx
                # to the maximum clique index
                self._cliques[idxs] = maxidx

                # since cliques are "contagious", reassign all the other
                # clique indices that we found to maxidx. -1 means "no
                # clique" and must not be merged, or every unassigned index
                # would be pulled into this clique. np.isin needs a list:
                # a set is treated as a single object and never matches.
                others = [cl for cl in allidx if cl != -1 and cl != maxidx]
                if others:
                    self._cliques[np.isin(self._cliques, others)] = maxidx

    # add cliques from individual pulsar phis; these will never overlap
    # TO DO: at this point Phi could be defined as a smarter KernelMatrix!
    def _setpulsarcliques(self, slices, phis):
        for sc, phi in zip(self._signalcollections, phis):
            if phi is not None:
                for clindex in range(getattr(phi, "_clcount", 0)):
                    phiind = np.where(phi._cliques == clindex)[0]

                    if len(phiind) > 0:
                        try:
                            self._cliques[slices[sc].start + phiind] = self._clcount
                            self._clcount = self._clcount + 1
                        except Exception:  # pragma: no cover
                            logger.exception("Exception raised in computing cliques")
                            logger.info(self._cliques.shape)
                            logger.info("phiind %s %s", phiind, len(phiind))
                            logger.info(slices)
                            raise

    def get_phi(self, params, cliques=False):
        """Return Phi for the whole array.

        :param cliques: when true, also rebuild ``self._cliques`` and
            ``self._clcount`` for ``get_phiinv_byfreq_cliques``
        :return: one dense matrix when common signals are present, otherwise
            the list of per-collection ``get_phi`` results (entries may be
            ``None``)
        """
        phis = [signalcollection.get_phi(params) for signalcollection in self._signalcollections]

        # if we found common signals, we'll return a big phivec matrix,
        # otherwise a list of phivec vectors (some of which possibly None)
        if self._commonsignals:
            if np.any([phi.ndim == 2 for phi in phis if phi is not None]):
                # if we have any dense matrices,
                Phi = sl.block_diag(*[np.diag(phi) if phi.ndim == 1 else phi for phi in phis if phi is not None])
            else:
                Phi = np.diag(np.concatenate([phi for phi in phis if phi is not None]))

            # get a dictionary of slices locating each pulsar in Phi matrix
            slices = self._get_slices(phis)

            # self._cliques is a vector of the same size as the Phi matrix
            # for each Phi index i, self._cliques[i] is -1 if row/column
            # belong to no clique, or it gives the clique number otherwise
            if cliques:
                self._resetcliques(Phi.shape[0])
                self._setpulsarcliques(slices, phis)

            # iterate over all common signal classes
            for csclass, csdict in self._commonsignals.items():
                # first figure out which indices are used in this common signal
                # and update the clique index
                if cliques:
                    self._setcliques(slices, csdict)

                # now iterate over all pairs of common signal instances
                pairs = itertools.combinations(csdict.items(), 2)

                for (cs1, csc1), (cs2, csc2) in pairs:
                    crossdiag = csclass.get_phicross(cs1, cs2, params)

                    block1, idx1 = slices[csc1], csc1._idx[cs1]
                    block2, idx2 = slices[csc2], csc2._idx[cs2]

                    if crossdiag.ndim == 1:
                        Phi[block1, block2][idx1, idx2] += crossdiag
                        Phi[block2, block1][idx2, idx1] += crossdiag
                    else:
                        Phi[block1, block2][np.ix_(idx1, idx2)] += crossdiag
                        Phi[block2, block1][np.ix_(idx2, idx1)] += crossdiag

            return Phi
        else:
            return phis

    def map_params(self, xs):
        """Convert a flat parameter vector into a ``{name: value}`` dictionary.

        Parameters are consumed in the order of ``self.params``. A parameter
        spanning more than one element gets a slice of ``xs``; any other gets
        a Python ``float``.
        """
        ret = {}
        ct = 0
        for p in self.params:
            n = p.size if p.size else 1
            ret[p.name] = xs[ct : ct + n] if n > 1 else float(xs[ct])
            ct += n
        return ret

    def get_lnprior(self, params):
        """Return the sum of ``get_logpdf`` over all parameters."""
        # map parameter vector if needed
        params = params if isinstance(params, dict) else self.map_params(params)

        return np.sum([p.get_logpdf(params=params) for p in self.params])

    @property
    def pulsars(self):
        """Pulsar names (``psrname``) of the signal collections, in order."""
        return [p.psrname for p in self._signalcollections]

    def _set_signal_dict(self):
        """Set signal dictionary, keeping the first signal seen for each name."""
        self._signal_dict = {}
        for sc in self._signalcollections:
            for sig in sc._signals:
                if sig.name not in self._signal_dict:
                    self._signal_dict[sig.name] = sig
                else:
                    msg = "Duplicate signal {} from objects {} and {}."
                    msg += "\nThis functionality was added in v1.1.0 and may"
                    msg += " cause post v1.1.0 functionality to break."
                    msg += "\nThis may not cause other errors but it is"
                    msg += " recommended that you use a custom name for one"
                    msg += " of the duplicate signals.\n"
                    logger.warning(msg.format(sig.name, sig, self._signal_dict[sig.name]))

    @property
    def signals(self):
        """Return signal dictionary."""
        return self._signal_dict

    def get_signal(self, name):
        """Returns ``Signal`` instance given the signal name."""
        return self._signal_dict[name]

    def summary(self, include_params=True, to_stdout=False):
        """generate summary string for PTA model

        :param include_params: [bool]
            list all parameters for each signal
        :param to_stdout: [bool]
            emit the summary through the module logger (INFO level)
            instead of returning it
        :return: [string] the summary, or ``None`` when ``to_stdout`` is set
        """
        row_fmt = "{: <40} {: <30} {: <20}\n"

        parts = [row_fmt.format("Signal Name", "Signal Class", "no. Parameters")]
        parts.append("=" * 90 + "\n")

        cpcount, copcount = 0, 0
        for sc in self._signalcollections:
            for sig in sc._signals:
                names = sig.param_names

                # parameter names that do not contain the pulsar name are
                # counted as common
                cpcount += sum(1 for p in names if sc.psrname not in p)

                # count fixed parameters whether or not they are listed, so
                # the totals below do not depend on include_params
                copcount += sum(1 for par in sig._params.values() if isinstance(par, ConstantParameter))

                parts.append(row_fmt.format(sig.name, sig.__class__.__name__, len(names)))
                if include_params:
                    parts.append("\n")
                    parts.append("params:\n")
                    for par in sig._params.values():
                        parts.append("{!s: <90}\n".format(par.__repr__()))
                parts.append("_" * 90 + "\n")

        nvarying = len(self.param_names)
        parts.append("=" * 90 + "\n")
        parts.append("Total params: {}\n".format(nvarying + copcount))
        parts.append("Varying params: {}\n".format(nvarying))
        parts.append("Common params: {}\n".format(cpcount))
        parts.append("Fixed params: {}\n".format(copcount))
        parts.append("Number of pulsars: {}\n".format(len(self._signalcollections)))

        summary = "".join(parts)
        if to_stdout:
            logger.info(summary)
        else:
            return summary
def SignalCollection(metasignals):
    """Class factory for ``SignalCollection`` objects.

    :param metasignals: iterable of signal classes; each is instantiated with
        the pulsar when the returned class is instantiated
    :return: a new ``SignalCollection`` class bound to ``metasignals``
    """

    @six.add_metaclass(MetaCollection)
    class SignalCollection(object):
        _metasignals = metasignals

        def __init__(self, psr):
            self.psrname = psr.name

            # instantiate all the signals with a pulsar
            self._signals = [metasignal(psr) for metasignal in self._metasignals]

            self._residuals = psr.residuals

            self._set_cache_parameters()

        def __add__(self, other):
            return PTA([self, other])

        # TODO: this could be implemented more cleanly
        def _set_cache_parameters(self):
            """Sets the cache for various signal types."""
            self.white_params = []
            self.basis_params = []
            self.delay_params = []

            for signal in self._signals:
                if signal.signal_type == "white noise":
                    self.white_params.extend(signal.ndiag_params)
                elif signal.signal_type in ["basis", "common basis"]:
                    # to support GP coefficients, and yet do the right thing
                    # for common GPs, which do not have coefficients yet
                    self.delay_params.extend(getattr(signal, "delay_params", []))
                    self.basis_params.extend(signal.basis_params)
                elif signal.signal_type in ["deterministic"]:
                    self.delay_params.extend(signal.delay_params)
                else:
                    msg = "{} signal type not recognized! Caching ".format(signal.signal_type)
                    msg += "may not work correctly for this signal."
                    logger.error(msg)

        # a candidate for memoization
        @property
        def params(self):
            """Unique parameters of all signals, sorted by name."""
            return sorted({param for signal in self._signals for param in signal.params}, key=lambda par: par.name)

        @property
        def param_names(self):
            """Flat list of parameter names; sized parameters expand to ``name_i``."""
            ret = []
            for p in self.params:
                if p.size:
                    for ii in range(0, p.size):
                        ret.append(p.name + "_{}".format(ii))
                else:
                    ret.append(p.name)
            return ret

        @property
        def signals(self):
            return self._signals

        def set_default_params(self, params):
            for signal in self._signals:
                signal.set_default_params(params)

        def _combine_basis_columns(self, signals):
            """Given a set of Signal objects, each of which may return an
            Fmat (through get_basis()), return a dict (indexed by signal)
            of integer arrays that map individual Fmat columns to the
            combined Fmat.

            Note: The Fmat returned here is simply meant to initialize the
            matrix to save computations when calling `get_basis` later.

            :return: ``(idx, Fmat)``, or ``({}, None)`` when no signal
                provides a basis
            """
            # seen: hash of a column -> index of that column in the combined
            # Fmat. cc also advances for parametrized bases, so the column
            # index has to be stored explicitly; the number of hashes seen so
            # far is not the same thing.
            idx, seen, cc, nrow = {}, {}, 0, None

            for signal in signals:
                Fmat = signal.get_basis()
                if Fmat is None:
                    continue

                # taken from the basis itself so it is available even when
                # every basis is parametrized
                nrow = Fmat.shape[0]

                if signal.basis_params:
                    # parametrized bases always get their own columns
                    nf = Fmat.shape[1]
                    idx[signal] = list(range(cc, cc + nf))
                    cc += nf
                else:
                    idx[signal] = []

                    for column in Fmat.T:
                        colhash = hash(column.tobytes())

                        if signal.basis_combine and colhash in seen:
                            # identical column seen before: share it
                            idx[signal].append(seen[colhash])
                        else:
                            idx[signal].append(cc)
                            # keep the first occurrence as the shared target
                            seen.setdefault(colhash, cc)
                            cc += 1

            if not idx:
                return {}, None
            else:
                ncol = len(np.unique(sum(idx.values(), [])))
                return ({key: np.array(idx[key]) for key in idx.keys()}, np.zeros((nrow, ncol)))

        # goofy way to cache _idx
        def __getattr__(self, par):
            # only reached when normal lookup fails, i.e. before the first
            # assignment of _idx/_Fmat below
            if par in ("_idx", "_Fmat"):
                self._idx, self._Fmat = self._combine_basis_columns(self._signals)
                return getattr(self, par)
            else:
                raise AttributeError("{} object has no attribute {}".format(self.__class__, par))

        @cache_call("white_params")
        def get_ndiag(self, params):
            """Sum of the non-``None`` ``get_ndiag`` results of all signals."""
            ndiags = [signal.get_ndiag(params) for signal in self._signals]
            return sum(ndiag for ndiag in ndiags if ndiag is not None)

        @cache_call("delay_params")
        def get_delay(self, params):
            """Sum of the non-``None`` ``get_delay`` results of all signals."""
            delays = [signal.get_delay(params) for signal in self._signals]
            return sum(delay for delay in delays if delay is not None)

        @cache_call("delay_params")
        def get_detres(self, params):
            """Residuals minus the summed delay."""
            return self._residuals - self.get_delay(params)

        @cache_call("basis_params")
        def get_basis(self, params={}):
            """Fill the combined basis matrix in place and return it."""
            for signal in self._signals:
                if signal in self._idx:
                    self._Fmat[:, self._idx[signal]] = signal.get_basis(params)
            return self._Fmat

        def get_phiinv(self, params):
            return self.get_phi(params).inv()

        # returns a KernelMatrix object
        def get_phi(self, params):
            if self._Fmat is None:
                return None
            phi = KernelMatrix(self._Fmat.shape[1])
            for signal in self._signals:
                if signal in self._idx:
                    phi = phi.add(signal.get_phi(params), self._idx[signal])
            return phi

        @cache_call(["basis_params", "white_params", "delay_params"])
        def get_TNr(self, params):
            """Return ``T^T N^-1 r``, or ``None`` when there is no basis."""
            T = self.get_basis(params)
            if T is None:
                return None
            Nvec = self.get_ndiag(params)
            res = self.get_detres(params)
            return Nvec.solve(res, left_array=T)

        @cache_call(["basis_params", "white_params"])
        def get_TNT(self, params):
            """Return ``T^T N^-1 T``, or ``None`` when there is no basis."""
            T = self.get_basis(params)
            if T is None:
                return None
            Nvec = self.get_ndiag(params)
            return Nvec.solve(T, left_array=T)

        @cache_call(["white_params", "delay_params"])
        def get_rNr_logdet(self, params):
            """Return the pair ``(r^T N^-1 r, logdet N)``."""
            Nvec = self.get_ndiag(params)
            res = self.get_detres(params)
            return Nvec.solve(res, left_array=res, logdet=True)

        # TO DO: cache how?
        def get_logsignalprior(self, params):
            return sum(signal.get_logsignalprior(params) for signal in self._signals)

    return SignalCollection
def cache_call(attrs, limit=2):
    """This decorator caches the output of a class method that takes
    a single parameter 'params'. It saves the cache in the instance
    attributes _cache_<methodname> and _cache_list_<methodname>.

    The cache keys are listed in the class attribute (or attributes)
    specified in the initial decorator call. For instance, if
    the decorator is applied as @cache_call('basis_params'), then
    the parameters listed in self.basis_params (together with their values)
    will be used as the key.

    The parameter 'limit' specifies the number of entries saved
    in the cache; the oldest entry is dropped first."""
    # function-scope import: the module header does not import functools
    import functools

    # convert to list of lists if only one attribute used
    if not isinstance(attrs, list):
        attrs = [attrs]

    def cache_decorator(func):
        cache_attr = "_cache_" + func.__name__
        order_attr = "_cache_list_" + func.__name__

        # wraps() keeps the decorated method's __name__ and docstring
        @functools.wraps(func)
        def wrapper(self, params={}):
            # get the relevant parameters to be cached
            keys = sum([getattr(self, attr) for attr in attrs], [])
            ret = []

            # TODO: this deals with vector parameters but could be cleaner...
            for key in keys:
                if key in params:
                    if np.ndim(params[key]) > 0:
                        # arrays/lists are not hashable; tuples are
                        ret.append((key, tuple(params[key])))
                    else:
                        ret.append((key, params[key]))
            key = tuple(ret)

            # make sure the cache is part of the object
            if not hasattr(self, cache_attr):
                msg = "Create cache {} for signal {}".format(func.__name__, self.__class__)
                logger.debug(msg)

                setattr(self, cache_attr, {})
                setattr(self, order_attr, [])

            cache = getattr(self, cache_attr)
            cache_list = getattr(self, order_attr)

            if key in cache:
                return cache[key]

            msg = "Setting cache for {} in {}: {}".format(attrs, self.__class__, key)
            logger.debug(msg)

            # compute before registering the key, so a failing call leaves
            # no stale key in the eviction queue
            result = func(self, params)
            cache_list.append(key)
            cache[key] = result

            if len(cache_list) > limit:
                cache.pop(cache_list.pop(0), None)

            # return the local value: with limit=0 the entry just stored has
            # already been evicted again
            return result

        return wrapper

    return cache_decorator
class csc_matrix_alt(sps.csc_matrix):
    """Sub-class of ``scipy.sparse.csc_matrix`` with custom ``add`` and
    ``solve`` methods.
    """

    def _add_diag(self, other):
        """Add the 1-d array ``other`` to the main diagonal."""
        size = other.shape[0]
        offsets = np.array([0])
        other_diag = sps.dia_matrix((other, offsets), shape=(size, size))
        return self._binopt(other_diag, "_plus_")

    def __add__(self, other):
        # a 1-d array is interpreted as a diagonal matrix
        is_diagonal = isinstance(other, (np.ndarray, ndarray_alt)) and other.ndim == 1
        if is_diagonal:
            return self._add_diag(other)
        return super(csc_matrix_alt, self).__add__(other)

    # hacky way to fix adding ndarray on left
    def __radd__(self, other):
        if isinstance(other, (np.ndarray, ndarray_alt)) or other == 0:
            return self.__add__(other)
        raise TypeError

    def solve(self, other, left_array=None, logdet=False):
        """Solve against ``other`` through a Cholesky factorization of ``self``.

        :param left_array: if given, the solution is left-multiplied by
            ``left_array.T``
        :param logdet: if true, return ``(result, log-determinant)``
        """
        factor = cholesky(self)
        result = factor(other)
        if left_array is not None:
            result = np.dot(left_array.T, result)
        if logdet:
            return (result, factor.logdet())
        return result
class ndarray_alt(np.ndarray):
    """Sub-class of ``np.ndarray`` with custom ``solve`` method."""

    def __new__(cls, inputarr):
        obj = np.asarray(inputarr).view(cls)
        return obj

    def __add__(self, other):
        try:
            ret = super(ndarray_alt, self).__add__(other)
        except Exception:
            # ndarray addition failed for this operand; let the other object
            # handle the sum. (Exception rather than a bare except, so that
            # KeyboardInterrupt/SystemExit are not swallowed.)
            ret = other + self
        return ret

    def solve(self, other, left_array=None, logdet=False):
        """Divide ``other`` by this array, treated as a diagonal matrix.

        :param other: 1-d or 2-d array; a 2-d array is divided row-wise
        :param left_array: if given, the result is left-multiplied by
            ``left_array.T``
        :param logdet: if true, return ``(result, sum(log(self)))``
        :raises TypeError: if ``other`` is neither 1-d nor 2-d
        """
        if other.ndim == 1:
            mult = np.array(other / self)
        elif other.ndim == 2:
            mult = np.array(other / self[:, None])
        else:
            raise TypeError("solve() expects a 1-d or 2-d array, got ndim={}".format(other.ndim))
        if left_array is not None:
            mult = np.dot(left_array.T, mult)

        ret = (mult, float(np.sum(np.log(self)))) if logdet else mult
        return ret
class BlockMatrix(object):
    """Matrix made of a diagonal ``nvec`` plus dense blocks on given slices.

    :param blocks: sequence of square arrays, one per slice
    :param slices: sequence of ``slice`` objects locating each block
    :param nvec: diagonal vector; usually supplied later by adding an array
        to the object (see ``__add__``)
    """

    def __init__(self, blocks, slices, nvec=0):
        self._blocks = blocks
        self._slices = slices
        self._nvec = nvec
        # indices not covered by any slice; always defined so the solvers
        # can test len(self._idx) whatever nvec is
        self._idx = np.array([], dtype=int)
        if np.any(nvec != 0):
            uncovered = np.ones(len(nvec), dtype=bool)
            for slc in slices:
                uncovered[slc] = False
            self._idx = np.flatnonzero(uncovered)

    def __add__(self, other):
        nvec = self._nvec + other
        return BlockMatrix(self._blocks, self._slices, nvec)

    # hacky way to fix adding 0
    def __radd__(self, other):
        if other == 0:
            return self.__add__(other)
        else:
            raise TypeError

    def _solve_ZNX(self, X, Z):
        """Solves :math:`Z^T N^{-1}X`, where :math:`X`
        and :math:`Z` are 1-d or 2-d arrays.
        """
        if X.ndim == 1:
            X = X.reshape(X.shape[0], 1)
        if Z.ndim == 1:
            Z = Z.reshape(Z.shape[0], 1)

        n, m = Z.shape[1], X.shape[1]
        ZNX = np.zeros((n, m))

        # purely diagonal part
        if len(self._idx) > 0:
            ZNXr = np.dot(Z[self._idx, :].T, X[self._idx, :] / self._nvec[self._idx, None])
        else:
            ZNXr = 0

        for slc, block in zip(self._slices, self._blocks):
            Zblock = Z[slc, :]
            Xblock = X[slc, :]

            if slc.stop - slc.start > 1:
                cf = sl.cho_factor(block + np.diag(self._nvec[slc]))
                bx = sl.cho_solve(cf, Xblock)
            else:
                # single-element slices use the diagonal only
                bx = Xblock / self._nvec[slc][:, None]
            ZNX += np.dot(Zblock.T, bx)

        ZNX += ZNXr
        # a 1x1 result becomes a Python float; item() avoids converting a
        # 2-d array with float(), which NumPy deprecates
        return ZNX.squeeze() if ZNX.size > 1 else float(ZNX.item())

    def _solve_NX(self, X):
        """Solves :math:`N^{-1}X`, where :math:`X`
        is a 1-d or 2-d array.
        """
        if X.ndim == 1:
            X = X.reshape(X.shape[0], 1)

        NX = X / self._nvec[:, None]
        for slc, block in zip(self._slices, self._blocks):
            Xblock = X[slc, :]
            if slc.stop - slc.start > 1:
                cf = sl.cho_factor(block + np.diag(self._nvec[slc]))
                NX[slc] = sl.cho_solve(cf, Xblock)
        return NX.squeeze()

    def _get_logdet(self):
        """Returns log determinant of :math:`N+UJU^{T}` where :math:`U`
        is a quantization matrix.
        """
        if len(self._idx) > 0:
            logdet = np.sum(np.log(self._nvec[self._idx]))
        else:
            logdet = 0

        for slc, block in zip(self._slices, self._blocks):
            if slc.stop - slc.start > 1:
                cf = sl.cho_factor(block + np.diag(self._nvec[slc]))
                logdet += np.sum(2 * np.log(np.diag(cf[0])))
            else:
                logdet += np.sum(np.log(self._nvec[slc]))
        return logdet

    def solve(self, other, left_array=None, logdet=False):
        """Return ``N^-1 other``, or ``left_array^T N^-1 other`` if given.

        :param logdet: if true, return ``(result, log-determinant)``
        :raises TypeError: if an argument is neither 1-d nor 2-d
        """
        if other.ndim not in [1, 2]:
            raise TypeError
        if left_array is not None:
            if left_array.ndim not in [1, 2]:
                raise TypeError

        if left_array is not None:
            ret = self._solve_ZNX(other, left_array)
        else:
            ret = self._solve_NX(other)

        return (ret, self._get_logdet()) if logdet else ret
class ShermanMorrison(object):
    """Custom container class for Sherman-morrison array inversion."""

    def __init__(self, jvec, slices, nvec=0.0):
        self._jvec = jvec
        self._slices = slices
        self._nvec = nvec

    def __add__(self, other):
        return ShermanMorrison(self._jvec, self._slices, self._nvec + other)

    # hacky way to fix adding 0
    def __radd__(self, other):
        if other != 0:
            raise TypeError
        return self.__add__(other)

    def _rank_one_blocks(self):
        """Yield ``(slice, j, 1/nvec[slice], beta)`` for slices longer than one."""
        for slc, jv in zip(self._slices, self._jvec):
            if slc.stop - slc.start > 1:
                niblock = 1 / self._nvec[slc]
                beta = 1.0 / (np.einsum("i->", niblock) + 1.0 / jv)
                yield slc, jv, niblock, beta

    def _solve_D1(self, x):
        """Solves :math:`N^{-1}x` where :math:`x` is a vector."""
        Nx = x / self._nvec
        for slc, _, niblock, beta in self._rank_one_blocks():
            Nx[slc] -= beta * np.dot(niblock, x[slc]) * niblock
        return Nx

    def _solve_1D1(self, x, y):
        """Solves :math:`y^T N^{-1}x`, where :math:`x` and
        :math:`y` are vectors.
        """
        yNx = np.dot(y, x / self._nvec)
        for slc, _, niblock, beta in self._rank_one_blocks():
            yNx -= beta * np.dot(niblock, x[slc]) * np.dot(niblock, y[slc])
        return yNx

    def _solve_2D2(self, X, Z):
        """Solves :math:`Z^T N^{-1}X`, where :math:`X`
        and :math:`Z` are 2-d arrays.
        """
        ZNX = np.dot(Z.T / self._nvec, X)
        for slc, _, niblock, beta in self._rank_one_blocks():
            zn = np.dot(niblock, Z[slc, :])
            xn = np.dot(niblock, X[slc, :])
            ZNX -= beta * np.outer(zn.T, xn)
        return ZNX

    def _get_logdet(self):
        """Returns log determinant of :math:`N+UJU^{T}` where :math:`U`
        is a quantization matrix.
        """
        logdet = np.einsum("i->", np.log(self._nvec))
        for _, jv, _, beta in self._rank_one_blocks():
            logdet += np.log(jv) - np.log(beta)
        return logdet

    def solve(self, other, left_array=None, logdet=False):
        """Return ``N^-1 other``, or ``left_array^T N^-1 other`` if given.

        Supported shapes: 1-d ``other`` with no, 1-d or 2-d ``left_array``;
        2-d ``other`` with 1-d or 2-d ``left_array``.

        :param logdet: if true, return ``(result, log-determinant)``
        :raises TypeError: for any other combination of dimensions
        """
        left_ndim = None if left_array is None else left_array.ndim

        if other.ndim == 1:
            if left_ndim is None:
                ret = self._solve_D1(other)
            elif left_ndim == 1:
                ret = self._solve_1D1(other, left_array)
            elif left_ndim == 2:
                ret = np.dot(left_array.T, self._solve_D1(other))
            else:
                raise TypeError
        elif other.ndim == 2:
            if left_ndim == 2:
                ret = self._solve_2D2(other, left_array)
            elif left_ndim == 1:
                ret = np.dot(other.T, self._solve_D1(left_array))
            else:
                # covers a missing left_array as well as unsupported ranks
                raise TypeError
        else:
            raise TypeError

        if logdet:
            return (ret, self._get_logdet())
        return ret