From 2bfb8dcfb452a594c03c9278e86c4f30139a094a Mon Sep 17 00:00:00 2001 From: Victor Boussange Date: Tue, 14 Nov 2023 18:47:28 +0100 Subject: [PATCH] fixed bugs after v15 of ComponentArrays --- Project.toml | 2 +- src/inference.jl | 5 ++--- src/utils.jl | 4 ++-- 3 files changed, 5 insertions(+), 6 deletions(-) diff --git a/Project.toml b/Project.toml index ee6687f..385832d 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "PiecewiseInference" uuid = "27a201f4-b6a1-4745-b96e-0c27845dca54" authors = ["Victor "] -version = "0.9.7" +version = "0.9.8" [deps] Bijectors = "76274a88-744f-5084-9051-94815aaf08c4" diff --git a/src/inference.jl b/src/inference.jl index ac36617..49f6663 100644 --- a/src/inference.jl +++ b/src/inference.jl @@ -91,13 +91,12 @@ through the optimizers `optimizers`. Returns a `InferenceResult`. - `optimizers` : array of optimizers, e.g. `[Adam(0.01)]` - `epochs` : A vector with number of epochs for each optimizer in `optimizers`. - `batchsizes`: An vector of batch sizes, which should match the length of - `optimizers`. + `optimizers`. If nothing is provided, all segments are used at once (full batch). - `verbose_loss` : Whether to display loss during training. - `info_per_its = 50`: The frequency at which to display the training information. - `plotting` : Whether to plot the convergence loss during training. -- `cb` : A call back function. Must be of the form `cb(θs, p_trained, losses, - pred, ranges)`. +- `cb` : A call back function. Must be of the form `cb(p_trained, losses, pred, ranges)`. - `threshold` : The tolerance for stopping training. - `save_pred = true`: Whether to save the predictions. - `save_losses = true` : Whether to save the losses. diff --git a/src/utils.jl b/src/utils.jl index 0168b12..01d1e27 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -217,7 +217,7 @@ end function to_optim_space(p::ComponentArray, infprob::InferenceProblem) @unpack p0, p_bij = infprob - pairs = [reshape(p_bij[k](p[k]),:) for k in keys(p0)] + pairs = [reshape(p_bij[k](getproperty(p,k)),:) for k in keys(p0)] ax = getaxes(p0) return ComponentArray(vcat(pairs...), ax) end @@ -225,7 +225,7 @@ end # TODO /!\ order is not guaranteed! function to_param_space(θ::ComponentArray, infprob::InferenceProblem) @unpack p0, p_bij = infprob - pairs = [reshape(inverse(p_bij[k])(θ[k]),:) for k in keys(p0)] + pairs = [reshape(inverse(p_bij[k])(getproperty(θ,k)),:) for k in keys(p0)] ax = getaxes(p0) return ComponentArray(vcat(pairs...), ax) end \ No newline at end of file