Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
mraveri committed Sep 26, 2024
1 parent 646bdbc commit dfed9eb
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 4 deletions.
8 changes: 4 additions & 4 deletions tensiometer/mcmc_tension/flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,9 @@ def estimate_shift(flow, prior_flow=None, tol=0.05, max_iter=1000, step=100000):
counter = max_iter

# define threshold for tension calculation:
_thres = flow.log_probability(flow.cast(np.zeros(flow.num_params)))
_thres = flow.log_probability(flow.cast([np.zeros(flow.num_params)]))[0]
if prior_flow is not None:
_thres = _thres - prior_flow.log_probability(prior_flow.cast(np.zeros(prior_flow.num_params)))
_thres = _thres - prior_flow.log_probability(prior_flow.cast([np.zeros(prior_flow.num_params)]))[0]

_num_filtered = 0
_num_samples = 0
Expand Down Expand Up @@ -89,9 +89,9 @@ def estimate_shift_from_samples(flow, prior_flow=None):
"""

# define threshold for tension calculation:
_thres = flow.log_probability(flow.cast(np.zeros(flow.num_params)))
_thres = flow.log_probability(flow.cast([np.zeros(flow.num_params)]))[0]
if prior_flow is not None:
_thres = _thres - prior_flow.log_probability(prior_flow.cast(np.zeros(prior_flow.num_params)))
_thres = _thres - prior_flow.log_probability(prior_flow.cast([np.zeros(prior_flow.num_params)]))[0]

# calculate probability on the samples:
_s_prob = flow.log_probability(flow.cast(flow.chain_samples))
Expand Down
1 change: 1 addition & 0 deletions tensiometer/synthetic_probability/synthetic_probability.py
Original file line number Diff line number Diff line change
Expand Up @@ -2256,6 +2256,7 @@ def __init__(self, flows, **kwargs):
'param_names',
'param_labels',
'parameter_ranges',
'periodic_params',
'chain_samples',
'chain_loglikes',
'has_loglikes',
Expand Down

0 comments on commit dfed9eb

Please sign in to comment.