An implementation of the Glow generative model in jax, using the high-level API flax. Glow is a reversible (flow-based) generative model. It is built from a sequence of invertible transformations (normalizing flows) and is trained by maximizing the exact log-likelihood of the data.
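The following is a minimal sketch of an affine coupling layer. It illustrates what "reversible" means here and how the exact log-likelihood follows from the change-of-variables formula, log p(x) = log p(z) + log|det ∂z/∂x|. The affine coupling layer is one of the three components of a Glow flow step, alongside actnorm and the invertible 1×1 convolution. The sketch is written in plain jax.numpy rather than flax. The coupling network is an assumption: a single random linear map with a tanh-bounded log-scale stands in for the CNN a real model would use. None of these function names come from this repository.

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import norm


def init_params(key, dim):
    """Random weights for a toy coupling network (hypothetical stand-in for a CNN)."""
    half = dim // 2
    k_w, k_b = jax.random.split(key)
    out = 2 * (dim - half)  # log-scale and shift for the second half
    return {
        "w": 0.1 * jax.random.normal(k_w, (half, out)),
        "b": 0.1 * jax.random.normal(k_b, (out,)),
    }


def coupling_net(params, x_a):
    """Map the unchanged half x_a to a (log_scale, shift) pair."""
    h = x_a @ params["w"] + params["b"]
    log_s, t = jnp.split(h, 2, axis=-1)
    return jnp.tanh(log_s), t  # tanh bounds the log-scale (an assumption)


def coupling_forward(params, x):
    """x -> z, plus log|det dz/dx| for each example."""
    half = x.shape[-1] // 2
    x_a, x_b = x[..., :half], x[..., half:]
    log_s, t = coupling_net(params, x_a)
    z_b = x_b * jnp.exp(log_s) + t
    return jnp.concatenate([x_a, z_b], axis=-1), jnp.sum(log_s, axis=-1)


def coupling_inverse(params, z):
    """z -> x: z_a equals x_a, so the same scale and shift can be recomputed."""
    half = z.shape[-1] // 2
    z_a, z_b = z[..., :half], z[..., half:]
    log_s, t = coupling_net(params, z_a)
    x_b = (z_b - t) * jnp.exp(-log_s)
    return jnp.concatenate([z_a, x_b], axis=-1)


def log_likelihood(params, x):
    """Change of variables with a standard normal prior: log p(z) + log|det|."""
    z, logdet = coupling_forward(params, x)
    return jnp.sum(norm.logpdf(z), axis=-1) + logdet


params = init_params(jax.random.PRNGKey(0), 6)
x = jax.random.normal(jax.random.PRNGKey(1), (4, 6))
z, logdet = coupling_forward(params, x)
x_rec = coupling_inverse(params, z)
print(jnp.max(jnp.abs(x_rec - x)))  # ~0 up to float32 rounding
print(log_likelihood(params, x).shape)  # (4,)
```

Because x_a passes through unchanged, the inverse can recompute the same scale and shift. The Jacobian is also triangular, so its log-determinant reduces to the sum of the log-scales.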
The notebook can also be found on Kaggle, where the model was trained on a subset of the aligned CelebA dataset.
pip install jax jaxlib
pip install flax
Random samples can be generated as follows. For instance, to generate 16 samples with sampling temperature 0.7 and random seed 0:
python3 sample.py 16 -t 0.7 -s 0 --model_path [path]
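In Glow-style models, the sampling temperature T scales the standard deviation of the Gaussian prior. Latents are drawn as z ~ N(0, T²I) and then mapped back to image space through the inverse of the flow. The sketch below shows only that latent-sampling step. `sample_latents` and `toy_inverse_flow` are hypothetical names, not functions from sample.py, and the latent shape is an arbitrary placeholder.

```python
import jax
import jax.numpy as jnp


def sample_latents(key, num_samples, latent_shape, temperature):
    """Draw z ~ N(0, T^2 I): temperature T rescales the prior's standard deviation."""
    eps = jax.random.normal(key, (num_samples, *latent_shape))
    return temperature * eps


def toy_inverse_flow(z):
    """Hypothetical placeholder for a trained model's inverse pass z -> x."""
    return jnp.tanh(z)


key = jax.random.PRNGKey(0)  # seed 0
z = sample_latents(key, 16, (8, 8, 3), temperature=0.7)
x = toy_inverse_flow(z)
print(z.shape)  # (16, 8, 8, 3)
```

Temperatures below 1 concentrate the latents near the mode of the prior, which typically trades sample diversity for visual quality.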
A pretrained model can be found in the Kaggle notebook's outputs.
Note: Due to computational limits, the model was only trained for roughly 13 epochs. Compared to the original model, it also uses a shallower flow (K = 16 flow steps per scale).
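To show what "flow steps per scale" refers to, the sketch below follows the Glow/RealNVP multi-scale layout. Each level squeezes its input (space-to-depth) and applies K shape-preserving flow steps. It then factors out half of the channels as a latent, except at the last level. The 64×64×3 input size and the choice of 3 levels are assumptions for illustration, since neither is specified here. The helper names are hypothetical.

```python
import jax.numpy as jnp


def squeeze(x):
    """Space-to-depth: (N, H, W, C) -> (N, H/2, W/2, 4C). A pure permutation, so log|det| = 0."""
    n, h, w, c = x.shape
    x = x.reshape(n, h // 2, 2, w // 2, 2, c)
    x = x.transpose(0, 1, 3, 2, 4, 5)
    return x.reshape(n, h // 2, w // 2, 4 * c)


def unsqueeze(x):
    """Inverse of squeeze: (N, H, W, C) -> (N, 2H, 2W, C/4)."""
    n, h, w, c = x.shape
    x = x.reshape(n, h, w, 2, 2, c // 4)
    x = x.transpose(0, 1, 3, 2, 4, 5)
    return x.reshape(n, 2 * h, 2 * w, c // 4)


def latent_shapes(height, width, channels, num_levels):
    """Per-level latent shapes of a Glow-style multi-scale flow."""
    shapes = []
    h, w, c = height, width, channels
    for level in range(num_levels):
        h, w, c = h // 2, w // 2, 4 * c  # squeeze
        # K shape-preserving flow steps would run here (actnorm, 1x1 conv, coupling).
        if level < num_levels - 1:
            c //= 2  # split: half the channels are factored out as a latent
            shapes.append((h, w, c))
    shapes.append((h, w, c))  # final level keeps all of its channels
    return shapes


K = 16  # flow steps per scale, as in this model
NUM_LEVELS = 3  # assumption
print(latent_shapes(64, 64, 3, NUM_LEVELS))  # [(32, 32, 6), (16, 16, 12), (8, 8, 48)]
print(NUM_LEVELS * K)  # 48 flow steps in total
```

The latents together hold exactly as many values as the input image (6144 + 3072 + 3072 = 64 · 64 · 3), as required for an invertible mapping. Depth scales as levels × K, so halving K halves the number of flow steps at every scale.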