A toolbox to compute the robustness of STL formulas using computation graphs. This is the jax version of the STLCG toolbox originally implemented in PyTorch.
Requires Python 3.10+
Install the repo:
pip install git+https://github.com/UW-CTRL/stljax.git
Alternatively, if you would like to install the package in editable mode:
git clone https://github.com/UW-CTRL/stljax.git
cd stljax
pip install -e .
(Best to use a virtual environment.)
demo.ipynb is a Jupyter notebook that showcases the basic functionality of the toolbox:
- Setting up signals for the formulas, including the use of Expressions and Predicates
- Defining STL formulas and visualizing them
- Evaluating STL robustness and the robustness trace
stljax leverages the benefits of jax and automatic differentiation!
Aside from using jax as the backend, stljax is a more recent and tidier implementation of stlcg, which was originally implemented in PyTorch around 2019.
- Removed the distributed_mean hack from the original stlcg implementation. jax keeps track of multiple max/min values and will distribute the gradients across all max/min values!
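The gradient behavior mentioned above can be checked with plain jax, independent of stljax. This is a minimal sketch with a made-up signal in which two entries tie for the maximum; it illustrates the underlying jax behavior only, not stljax internals.

```python
import jax
import jax.numpy as jnp

# Made-up signal: the last two entries tie for the maximum.
signal = jnp.array([1.0, 3.0, 3.0])

# Gradient of the max with respect to every entry of the signal.
grad = jax.grad(jnp.max)(signal)

# The two tied entries share the gradient equally (0.5 each),
# and the non-maximal entry receives zero gradient.
print(grad)
```

The same sharing applies to jnp.min, which is why no extra averaging step is needed for ties.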
Tags 🏷️ | Description |
---|---|
v.1.1.0 | General code improvements. Included recurrent implementation and example notebooks. |
v.1.0.0 | Removed awkward expected signal dimension & leverage vmap for batched inputs. Masking for temporal operations & remove need to reverse signals. |
v0.0.0 | A transfer from the 2019 PyTorch implementation to Jax + some tidying + adding Predicates + reversing signal automatically. |
There are two ways to define an STL formula: using either the Expression or the Predicate class.
With Expression, you are essentially defining a signal whose values are the output of a predicate function computed outside of the STL robustness computation.
Essentially, you process your desired signal first (e.g., from a state trajectory, you compute velocity), and then you pass it directly into the STL formula.
A step-by-step breakdown:
- Suppose you have a trajectory that is an array of size [time_steps, state_dim].
- Suppose we have a get_velocity() function and a get_acceleration() function:
velocity_value = get_velocity(trajectory) # [time_steps]
acceleration_value = get_acceleration(trajectory) # [time_steps]
- Then, we can define the following two Expression objects:
velocity_exp = Expression("velocity", value=velocity_value)
acceleration_exp = Expression("acceleration", value=acceleration_value)
- With these two expressions, we can define an STL formula ϕ = □ (velocity_exp > 5.0) ∧ ◊ (acceleration_exp > 5.0), which is equivalent to ϕ = Always(velocity_exp > 5.0) & Eventually(acceleration_exp > 5.0).
- To compute the robustness trace of ϕ, we run ϕ((velocity_exp, acceleration_exp)), where the input is a tuple since the first part of the formula depends on velocity and the second part depends on acceleration.
This means that the user needs to compute the velocity and acceleration values before calling ϕ to compute the robustness trace (or ϕ.robustness((velocity_exp, acceleration_exp)) for the robustness value).
NOTE: Expressions are used to define an STL formula. While you can, you don't necessarily need to use Expressions as inputs for computing robustness values, so ϕ((velocity_value, acceleration_value)) should also work.
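To make the difference between a robustness trace and a robustness value concrete, here is a minimal plain-jax sketch of the usual quantitative semantics for Always(velocity > 5.0) & Eventually(acceleration > 5.0). It does not use stljax. The helper names always_trace and eventually_trace, the sample values, and the convention that entry t of the trace is the robustness of the signal from time t onward are all assumptions made for illustration.

```python
import jax
import jax.numpy as jnp

def always_trace(signal, threshold):
    # Entry t is the minimum of (signal - threshold) over all times t' >= t.
    return jax.lax.cummin(signal - threshold, axis=0, reverse=True)

def eventually_trace(signal, threshold):
    # Entry t is the maximum of (signal - threshold) over all times t' >= t.
    return jax.lax.cummax(signal - threshold, axis=0, reverse=True)

# Made-up signals of shape [time_steps], computed outside the formula.
velocity_value = jnp.array([6.0, 7.0, 5.5, 8.0])
acceleration_value = jnp.array([1.0, 6.0, 2.0, 0.0])

# Conjunction is the pointwise minimum of the two sub-formula traces.
robustness_trace = jnp.minimum(
    always_trace(velocity_value, 5.0),
    eventually_trace(acceleration_value, 5.0),
)

# Under the assumed convention, the robustness value is the trace at time 0.
robustness_value = robustness_trace[0]
print(robustness_trace, robustness_value)
```

A positive robustness value means the whole signal satisfies the formula, and its magnitude indicates by how much.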
Predicate is truer to the STL definition: you pass a predicate function when defining an STL formula rather than passing the signal that would be the output of a predicate function.
Essentially, you pass your N-D input (e.g., state trajectory) directly into the formula when computing robustness values.
A step-by-step breakdown:
- Suppose you have a trajectory that is an array of size [time_steps, state_dim].
- Suppose we have a get_velocity() function and a get_acceleration() function:
velocity_value = get_velocity(trajectory) # [time_steps]
acceleration_value = get_acceleration(trajectory) # [time_steps]
- Then, we can define the following two Predicate objects:
velocity_pred = Predicate("velocity", predicate_function=get_velocity)
acceleration_pred = Predicate("acceleration", predicate_function=get_acceleration)
- With these two Predicate objects, we can define an STL formula ϕ = □ (velocity_pred > 5.0) ∧ ◊ (acceleration_pred > 5.0), which is equivalent to ϕ = Always(velocity_pred > 5.0) & Eventually(acceleration_pred > 5.0).
- To compute the robustness trace of ϕ, we run ϕ(trajectory), where the input is whatever all the predicate functions expect as their input.
In summary:
When using Predicates to define STL formulas, the velocity and acceleration values are extracted inside the robustness computation, whereas when using Expressions, you need to extract the velocity and acceleration outside of the robustness computation.
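The contrast can be sketched in plain jax without stljax. Everything here is hypothetical and for illustration only: the state layout [position, velocity, acceleration], the bodies of get_velocity and get_acceleration, and the two wrapper functions that stand in for an Expression-style and a Predicate-style formula.

```python
import jax
import jax.numpy as jnp

# Hypothetical state layout: columns are [position, velocity, acceleration].
def get_velocity(trajectory):
    return trajectory[:, 1]

def get_acceleration(trajectory):
    return trajectory[:, 2]

def robustness(velocity, acceleration):
    # Robustness value of Always(velocity > 5.0) & Eventually(acceleration > 5.0).
    return jnp.minimum(jnp.min(velocity - 5.0), jnp.max(acceleration - 5.0))

def phi_from_signals(signals):
    # Expression style: the caller extracts the signals beforehand.
    velocity, acceleration = signals
    return robustness(velocity, acceleration)

def phi_from_trajectory(trajectory):
    # Predicate style: the signals are extracted inside the computation.
    return robustness(get_velocity(trajectory), get_acceleration(trajectory))

# Made-up trajectory of shape [time_steps, state_dim].
trajectory = jnp.array([[0.0, 6.0, 1.0],
                        [0.6, 7.0, 6.0],
                        [1.3, 5.5, 2.0]])

r_expression = phi_from_signals((get_velocity(trajectory), get_acceleration(trajectory)))
r_predicate = phi_from_trajectory(trajectory)

# With the predicate style, gradients flow back to the raw trajectory.
grad_wrt_trajectory = jax.grad(phi_from_trajectory)(trajectory)
print(r_expression, r_predicate)
```

Both styles give the same robustness value; the difference is only where the predicate functions are applied, which in turn determines what the gradient is taken with respect to.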
We can use jax.vmap to handle multiple signals at once.
jax.vmap(formula)(signals) # signals is shape [bs, time_dim,...]
NOTE: Take care with formulas that are defined with Expressions and need multiple inputs. A wrapper is needed since jax.vmap doesn't like tuples in a single argument.
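One possible shape for such a wrapper is sketched below in plain jax. The functions always_gt_5 and formula are hypothetical stand-ins for an stljax formula, and formula_unpacked is just one way to pass each signal as its own positional argument so that jax.vmap maps over the leading batch axis of each.

```python
import jax
import jax.numpy as jnp

def always_gt_5(signal):
    # Stand-in for a single-input formula; signal has shape [time_dim].
    return jnp.min(signal - 5.0)

def formula(inputs):
    # Stand-in for a multi-input formula that expects a tuple of signals.
    velocity, acceleration = inputs
    return jnp.minimum(jnp.min(velocity - 5.0), jnp.max(acceleration - 5.0))

def formula_unpacked(velocity, acceleration):
    # Wrapper: take the signals as separate arguments, then rebuild the tuple.
    return formula((velocity, acceleration))

# Made-up batched signals of shape [bs, time_dim].
velocities = jnp.array([[6.0, 7.0, 8.0],
                        [4.0, 9.0, 6.0]])
accelerations = jnp.array([[0.0, 6.0, 1.0],
                           [7.0, 0.0, 0.0]])

# Single-input case: one robustness value per batch element.
batched_single = jax.vmap(always_gt_5)(velocities)

# Multi-input case: vmap maps over axis 0 of both arguments.
batched_multi = jax.vmap(formula_unpacked)(velocities, accelerations)
print(batched_single, batched_multi)
```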
- manage reversing of signals internally for recurrent cases.
Here is a list of publications that use stlcg/stljax. Please file an issue or open a pull request to add your publication to the list.
K. Leung and M. Pavone, "Semi-Supervised Trajectory-Feedback Controller Synthesis for Signal Temporal Logic Specifications," in American Control Conference, 2022.
K. Leung, N. Aréchiga, and M. Pavone, "Backpropagation through STL specifications: Infusing logical structure into gradient-based methods," International Journal of Robotics Research, 2022.
J. DeCastro, K. Leung, N. Aréchiga, and M. Pavone, "Interpretable Policies from Formally-Specified Temporal Properties," in Proc. IEEE Int. Conf. on Intelligent Transportation Systems, Rhodes, Greece, 2020.
K. Leung, N. Aréchiga, and M. Pavone, "Backpropagation for Parametric STL," in IEEE Intelligent Vehicles Symposium: Workshop on Unsupervised Learning for Automated Driving, Paris, France, 2019.
When citing stlcg/stljax, please use the following BibTeX:
# journal paper
@Article{LeungArechigaEtAl2022,
author = {Leung, K. and Ar\'{e}chiga, N. and Pavone, M.},
title = {Backpropagation through signal temporal logic specifications: Infusing logical structure into gradient-based methods},
journal = {{Int. Journal of Robotics Research}},
year = {2022},
}
# conference paper
@Inproceedings{LeungArechigaEtAl2020,
author = {Leung, K. and Ar\'{e}chiga, N. and Pavone, M.},
title = {Backpropagation through signal temporal logic specifications: Infusing logical structure into gradient-based methods},
booktitle = {{Workshop on Algorithmic Foundations of Robotics}},
year = {2020},
}
If there are any issues with the code, please file an issue or make a pull request.