Skip to content

Commit

Permalink
fixed bugs after v15 of ComponentArrays
Browse files Browse the repository at this point in the history
  • Loading branch information
vboussange committed Nov 14, 2023
1 parent b18a209 commit 2bfb8dc
Show file tree
Hide file tree
Showing 3 changed files with 5 additions and 6 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "PiecewiseInference"
uuid = "27a201f4-b6a1-4745-b96e-0c27845dca54"
authors = ["Victor <bvictor@ethz.ch>"]
version = "0.9.7"
version = "0.9.8"

[deps]
Bijectors = "76274a88-744f-5084-9051-94815aaf08c4"
Expand Down
5 changes: 2 additions & 3 deletions src/inference.jl
Original file line number Diff line number Diff line change
Expand Up @@ -91,13 +91,12 @@ through the optimizers `optimizers`. Returns a `InferenceResult`.
- `optimizers` : array of optimizers, e.g. `[Adam(0.01)]`
- `epochs` : A vector with number of epochs for each optimizer in `optimizers`.
- `batchsizes`: An vector of batch sizes, which should match the length of
`optimizers`.
`optimizers`. If nothing is provided, all segments are used at once (full batch).
- `verbose_loss` : Whether to display loss during training.
- `info_per_its = 50`: The frequency at which to display the training
information.
- `plotting` : Whether to plot the convergence loss during training.
- `cb` : A call back function. Must be of the form `cb(θs, p_trained, losses,
pred, ranges)`.
- `cb` : A call back function. Must be of the form `cb(p_trained, losses, pred, ranges)`.
- `threshold` : The tolerance for stopping training.
- `save_pred = true`: Whether to save the predictions.
- `save_losses = true` : Whether to save the losses.
Expand Down
4 changes: 2 additions & 2 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -217,15 +217,15 @@ end

function to_optim_space(p::ComponentArray, infprob::InferenceProblem)
@unpack p0, p_bij = infprob
pairs = [reshape(p_bij[k](p[k]),:) for k in keys(p0)]
pairs = [reshape(p_bij[k](getproperty(p,k)),:) for k in keys(p0)]
ax = getaxes(p0)
return ComponentArray(vcat(pairs...), ax)
end

# TODO /!\ order is not guaranteed!
function to_param_space::ComponentArray, infprob::InferenceProblem)
@unpack p0, p_bij = infprob
pairs = [reshape(inverse(p_bij[k])(θ[k]),:) for k in keys(p0)]
pairs = [reshape(inverse(p_bij[k])(getproperty(θ,k)),:) for k in keys(p0)]
ax = getaxes(p0)
return ComponentArray(vcat(pairs...), ax)
end

0 comments on commit 2bfb8dc

Please sign in to comment.