# parameter.py
"""Contains parameter types for use in `enterprise` ``Signal`` classes."""
from __future__ import (absolute_import, division,
print_function, unicode_literals)
import inspect
import functools
import numpy as np
import scipy.stats
from enterprise.signals.selections import selection_func
def sample(parlist):
    """Draw consistent random values for one or more Parameters.

    The hyperparameters each Parameter depends on are drawn as well, and
    every drawn value is returned keyed by its Parameter's name.

    :param parlist: a single Parameter instance, or an iterable of them
    :return: dict mapping parameter names to sampled values
    """
    if isinstance(parlist, Parameter):
        # be lenient: wrap a lone Parameter in a list
        parlist = [parlist]

    values = {}
    _sample(parlist, values)
    return values
def _sample(parlist, parvalues):
"""Recursive function used by sample()."""
for par in parlist:
if par not in parvalues:
parvalues.update(sample(par.params[1:]))
parvalues[par.name] = par.sample(params=parvalues)
class Parameter(object):
    """Abstract base class for sampling parameters.

    Concrete subclasses (normally produced by the class factories in this
    module) must supply the class attributes ``_size``, ``_prior``,
    ``_sampler`` and ``_typename``; this base class defines none of them.
    """

    def __init__(self, name):
        self.name = name
        # _prior is a factory (e.g. a class returned by Function);
        # build one bound to this parameter's name
        self.prior = self._prior(name)

    def get_logpdf(self, value=None, **kwargs):
        """Return the log of the prior density at `value`.

        When `value` is omitted it is looked up by this parameter's name in
        the ``params`` keyword. If ``_size`` is set, the logs are summed
        with ``np.sum``.
        """
        if value is None and 'params' in kwargs:
            value = kwargs['params'][self.name]
        log_density = np.log(self.prior(value, **kwargs))
        return np.sum(log_density) if self._size is not None else log_density

    def get_pdf(self, value=None, **kwargs):
        """Return the prior density at `value`.

        When `value` is omitted it is looked up by this parameter's name in
        the ``params`` keyword. If ``_size`` is set, the densities are
        multiplied together with ``np.prod``.
        """
        if value is None and 'params' in kwargs:
            value = kwargs['params'][self.name]
        density = self.prior(value, **kwargs)
        return np.prod(density) if self._size is not None else density

    def sample(self, **kwargs):
        """Draw a random value with this parameter's sampler.

        The sampler is handed to the prior as ``func`` (together with
        ``size``), so the prior supplies its resolved keyword arguments to it.

        :raises AttributeError: if no sampler was configured
        :raises ValueError: if this parameter's own name is passed as a
            keyword argument
        """
        if self._sampler is None:
            raise AttributeError("No sampler was provided for this Parameter.")
        if self.name in kwargs:
            raise ValueError(
                "You shouldn't give me my value when you're sampling me.!")
        return self.prior(func=self._sampler, size=self._size, **kwargs)

    @property
    def size(self):
        """The ``_size`` class attribute (None for an unsized parameter)."""
        return self._size

    @property
    def params(self):
        """This parameter followed by the prior's non-constant parameters."""
        hyper = [p for p in self.prior.params
                 if not isinstance(p, ConstantParameter)]
        return [self] + hyper

    def __repr__(self):
        label = self._typename.format(**self.prior._params)
        suffix = '[{}]'.format(self._size) if self._size is not None else ''
        return '"{}":{}{}'.format(self.name, label, suffix)

    def __call__(self, name):
        # Allows an already-instantiated parameter to be handed to a signal
        # where a Parameter class is expected: "renaming" it is a no-op.
        return self
def UserParameter(prior, sampler=None, size=None):
    """Build a Parameter class around a user-supplied prior.

    :param prior: an Enterprise Function giving the density; its single
        positional argument is the parameter value, and any keyword
        arguments it takes become hyperparameters
    :param sampler: optional plain callable (not an Enterprise Function)
        that accepts the same keyword arguments as `prior`; used by
        ``Parameter.sample``
    :param size: value stored as ``_size`` and passed to the sampler
        (None for an unsized parameter)
    :return: a new Parameter subclass
    """
    class UserParameter(Parameter):
        _typename = 'UserParameter'
        _prior = prior
        # staticmethod so the sampler is not bound to the instance
        _sampler = staticmethod(sampler)
        _size = size

    return UserParameter
def _argrepr(typename, **kwargs):
    """Build a ``str.format`` template describing a parameter type.

    Arguments are rendered as follows:

    * Parameter/ConstantParameter *classes* become ``"{kw.name}"``
      placeholders. ``Parameter.__repr__`` later fills these in with the
      names of the instantiated hyperparameters.
    * Parameter/ConstantParameter *instances* are rendered by name.
    * Anything else is rendered with ``str``.

    Literal text has its braces doubled so that it survives the later
    ``.format`` call unchanged.

    :param typename: name of the parameter type, e.g. ``'Normal'``
    :return: template such as ``'Normal(mu="{mu.name}",sigma=1)'``
    """
    def _literal(text):
        # escape braces so str.format treats them as plain characters
        return str(text).replace('{', '{{').replace('}', '}}')

    args = []
    for par, arg in kwargs.items():
        if isinstance(arg, type) and issubclass(
                arg, (Parameter, ConstantParameter)):
            args.append('{}="{{{}.name}}"'.format(par, par))
        elif isinstance(arg, (Parameter, ConstantParameter)):
            args.append('{}={}'.format(par, _literal(arg.name)))
        else:
            args.append('{}={}'.format(par, _literal(arg)))
    return '{}({})'.format(typename, ','.join(args))
# note: will not do a jointly normal prior
def NormalPrior(value, mu, sigma):
    """Prior function for Normal parameters.

    `sigma` can be:

    * a scalar, for a 1-d distribution;
    * a vector, whose entries are the square roots of the diagonal of the
      covariance matrix;
    * a 2-d matrix, which is used directly as the covariance.
    """
    # convert to an array before squaring so list-valued sigmas also work
    cov = sigma if np.ndim(sigma) == 2 else np.asarray(sigma)**2
    return scipy.stats.multivariate_normal.pdf(value, mean=mu, cov=cov)
def NormalSampler(mu, sigma, size=None):
    """Sampling function for Normal parameters.

    `sigma` is a scalar or a vector of standard deviations (squared to give
    the covariance, or its diagonal), or a 2-d covariance matrix that is
    used as is.
    """
    # convert to an array before squaring so list-valued sigmas also work
    cov = sigma if np.ndim(sigma) == 2 else np.asarray(sigma)**2
    return scipy.stats.multivariate_normal.rvs(mean=mu, cov=cov, size=size)
def Normal(mu=0, sigma=1, size=None):
    """Class factory for Normal (Gaussian) parameters.

    `mu` and `sigma` may be plain values, or Parameter classes/instances,
    which then act as hyperparameters of the prior. NormalPrior describes
    how `sigma` is interpreted.

    :return: a new Parameter subclass
    """
    normal_prior = Function(NormalPrior, mu=mu, sigma=sigma)
    label = _argrepr('Normal', mu=mu, sigma=sigma)

    class Normal(Parameter):
        _prior = normal_prior
        _sampler = staticmethod(NormalSampler)
        _typename = label
        _size = size

    return Normal
def LinearExpPrior(value, pmin, pmax):
    """Prior function for LinearExp parameters.

    The density is ``ln(10) * 10**value / (10**pmax - 10**pmin)`` for
    ``pmin <= value <= pmax``, and 0 elsewhere.

    :raises ValueError: if ``pmin >= pmax``
    """
    if pmin >= pmax:
        raise ValueError("LinearExp Parameter requires pmin < pmax.")
    inside = (pmin <= value) & (value <= pmax)
    # Out-of-range values are only multiplied by False (0). Exponentiating
    # them directly could still overflow and turn that 0 into nan (or raise
    # OverflowError for Python floats), so clip into range first. In-range
    # values are unchanged by the clip.
    clipped = np.clip(value, pmin, pmax)
    return inside * np.log(10) * 10.0**clipped / (10**pmax - 10**pmin)
def LinearExpSampler(pmin, pmax, size):
    """Sampling function for LinearExp parameters.

    Draws ``10**x`` uniformly between ``10**pmin`` and ``10**pmax`` with
    ``np.random.uniform`` and returns ``x``.

    :raises ValueError: if ``pmin >= pmax``
    """
    if pmin >= pmax:
        raise ValueError("LinearExp Parameter requires pmin < pmax.")
    low, high = 10**pmin, 10**pmax
    linear_draws = np.random.uniform(low, high, size)
    return np.log10(linear_draws)
def LinearExp(pmin, pmax, size=None):
    """Class factory for LinearExp parameters (pdf(x) proportional to 10^x).

    `pmin` and `pmax` may be plain values, or Parameter classes/instances,
    which then act as hyperparameters of the prior.

    :return: a new Parameter subclass
    """
    linexp_prior = Function(LinearExpPrior, pmin=pmin, pmax=pmax)
    label = _argrepr('LinearExp', pmin=pmin, pmax=pmax)

    class LinearExp(Parameter):
        _prior = linexp_prior
        _sampler = staticmethod(LinearExpSampler)
        _typename = label
        _size = size

    return LinearExp
class ConstantParameter(object):
    """Constant Parameter base class.

    Holds a fixed ``value`` rather than being sampled. Subclasses built by
    ``Constant`` define ``value`` as a plain class attribute, which takes
    precedence over the property defined here.
    """

    def __init__(self, name):
        self.name = name
        # Backing storage for the `value` property. Reading or writing
        # `self.value` inside the property itself would recurse forever.
        self._value = None

    @property
    def value(self):
        """The constant's value (None until one is assigned)."""
        return self._value

    @value.setter
    def value(self, value):
        self._value = value

    def __call__(self, name):
        # As with Parameter: an instance handed to a signal ignores renaming
        return self

    def __repr__(self):
        return '"{}":Constant={}'.format(self.name, self.value)
def Constant(val=None):
    """Class factory for ConstantParameter types.

    :param val: the fixed value, stored as the ``value`` class attribute
    :return: a new ConstantParameter subclass
    """
    class Constant(ConstantParameter):
        # plain class attribute; takes precedence over the base-class property
        value = val

    return Constant
def Function(func, name='', **func_kwargs):
    """Class factory wrapping `func` so its keyword arguments can be Parameters.

    Each keyword in `func_kwargs` is one of:

    * a Parameter/ConstantParameter *class* -- instantiated for each
      Function instance, with name ``<instance name>_<name>_<keyword>``
      (empty parts are dropped);
    * a Parameter/ConstantParameter *instance* -- used as-is (useful for
      global parameters shared between signals);
    * anything else -- a fixed value for that keyword.

    :param func: the function to wrap (passed through ``selection_func``)
    :param name: optional name inserted into generated parameter names
    :return: a new class; instantiate it with ``(name, psr=None)`` and call
        the instance like `func`
    """
    fname = name  # kept because __init__ below has its own `name` argument

    class Function(object):
        def __init__(self, name, psr=None):
            self._func = selection_func(func)
            self._psr = psr
            self._params = {}
            self._defaults = {}

            for kw, arg in func_kwargs.items():
                is_par_class = isinstance(arg, type) and issubclass(
                    arg, (Parameter, ConstantParameter))
                if is_par_class:
                    # parameter name template:
                    # pname_[signalname_][fname_]parname
                    parts = [part for part in (name, fname, kw) if part]
                    self._params[kw] = arg('_'.join(parts))
                elif isinstance(arg, (Parameter, ConstantParameter)):
                    self._params[kw] = arg
                else:
                    self._defaults[kw] = arg

        def __call__(self, *args, **kwargs):
            """Call the wrapped function, resolving Parameter-valued keywords.

            Each keyword is resolved in this order:

            1. an explicit keyword argument;
            2. the value stored under the parameter's name in ``params``, or
               else the parameter's own ``value`` attribute, if it has one;
            3. a fixed value from the factory call or ``add_kwarg``.

            ``psr`` is filled in from the instance when set and not given.
            Two special keywords are consumed here: ``params`` (a dict of
            parameter values) and ``func`` (an alternative callable to
            invoke with the same resolved arguments).
            """
            params = kwargs.pop('params', {})
            target = kwargs.pop('func', self._func)

            for kw in func_kwargs:
                if kw in kwargs or kw not in self._params:
                    continue
                par = self._params[kw]
                if par.name in params:
                    kwargs[kw] = params[par.name]
                elif hasattr(par, 'value'):
                    kwargs[kw] = par.value

            for kw, default in self._defaults.items():
                kwargs.setdefault(kw, default)

            if self._psr is not None and 'psr' not in kwargs:
                kwargs['psr'] = self._psr

            return target(*args, **kwargs)

        def add_kwarg(self, **kwargs):
            """Add or override fixed keyword values."""
            self._defaults.update(kwargs)

        @property
        def params(self):
            """Flattened ``params`` lists of all non-constant parameters."""
            # if we extract the ConstantParameter value above, we would not
            # need a special case here
            collected = []
            for par in self._params.values():
                if not isinstance(par, ConstantParameter):
                    collected.extend(par.params)
            return collected

    return Function
def get_funcargs(func):
    """Convenience function to get args and kwargs of any function.

    :param func: a Python function
    :return: ``(args, kwargs)``: a list of the names of positional parameters
        without defaults, and a list of those with defaults. Keyword-only
        parameters, ``*args`` and ``**kwargs`` are not reported.
    """
    # inspect.getargspec was deprecated in Python 3 and removed in 3.11, so
    # prefer getfullargspec and fall back only where it is missing (Python 2).
    getspec = getattr(inspect, 'getfullargspec', None) or inspect.getargspec
    argspec = getspec(func)
    names = list(argspec.args)
    # defaults is None when no parameter has a default
    split = len(names) - len(argspec.defaults or ())
    return names[:split], names[split:]
def function(func):
    """Decorator that turns calls to `func` into a Function when needed.

    Calling the decorated function returns ``Function(func, **kwargs)``
    when either:

    * some positional parameter of `func` (one without a default) is not
      supplied, positionally or by keyword; or
    * some keyword argument is a Parameter/ConstantParameter class or
      instance.

    Otherwise it simply returns ``func(*args, **kwargs)``. Only keyword
    arguments are forwarded to ``Function``.
    """
    required, _ = get_funcargs(func)

    def _is_parameter(arg):
        # a Parameter/ConstantParameter class, or an instance of one
        return ((isinstance(arg, type) and issubclass(
            arg, (Parameter, ConstantParameter))) or isinstance(
                arg, (Parameter, ConstantParameter)))

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        # map positional values onto the required names, then add keywords
        supplied = dict(zip(required, args))
        supplied.update(kwargs)

        if any(arg_name not in supplied for arg_name in required):
            return Function(func, **kwargs)
        if any(_is_parameter(value) for value in kwargs.values()):
            return Function(func, **kwargs)
        return func(*args, **kwargs)

    return wrapper