An implementation of the Glow generative model in jax and flax

ameroyer/glow_jax

Glow generative model in jax

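An implementation of the Glow generative model in jax, using the high-level neural network library flax. Glow is a reversible, flow-based generative model: a stack of invertible transformations (normalizing flows) maps images to a simple Gaussian latent distribution. This makes it possible to compute and maximize the exact log-likelihood, rather than the variational lower bound used by variational auto-encoders. The accompanying notebook can also be found on Kaggle, where the model was trained on a subset of the aligned CelebA dataset.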
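To make the "reversible" property concrete, here is a minimal sketch of an affine coupling layer, one of the building blocks of a Glow flow step (alongside actnorm and the invertible 1x1 convolution). It is written in NumPy for brevity and is not this repository's flax code. `nn_stub` is a hypothetical stand-in for the learned network (Glow uses a small CNN here), and bounding the log-scale with `tanh` is an assumption made in this sketch for numerical stability. The same code ports directly to `jax.numpy`.

```python
import numpy as np

def nn_stub(x_a, params):
    # Hypothetical stand-in for the learned network NN(x_a) that predicts
    # the coupling parameters (a small CNN in Glow, a linear map here).
    w_s, w_t = params
    log_s = np.tanh(x_a @ w_s)  # bounded log-scale (assumption of this sketch)
    t = x_a @ w_t
    return log_s, t

def coupling_forward(x, params):
    """Affine coupling x -> z, also returning log|det dz/dx| per example."""
    x_a, x_b = np.split(x, 2, axis=-1)
    log_s, t = nn_stub(x_a, params)
    z_b = x_b * np.exp(log_s) + t          # only the second half is transformed
    logdet = log_s.sum(axis=-1)            # Jacobian is triangular
    return np.concatenate([x_a, z_b], axis=-1), logdet

def coupling_inverse(z, params):
    """Exact inverse: z_a == x_a, so the same scale and shift can be recomputed."""
    z_a, z_b = np.split(z, 2, axis=-1)
    log_s, t = nn_stub(z_a, params)
    x_b = (z_b - t) * np.exp(-log_s)
    return np.concatenate([z_a, x_b], axis=-1)

def log_likelihood(x, params):
    """Change of variables: log p(x) = log N(z; 0, I) + log|det dz/dx| (in nats)."""
    z, logdet = coupling_forward(x, params)
    log_pz = -0.5 * (z ** 2 + np.log(2 * np.pi)).sum(axis=-1)
    return log_pz + logdet
```

For example, with `x` of shape `(batch, 8)` and two `(4, 4)` weight matrices as `params`, `coupling_inverse(coupling_forward(x, params)[0], params)` recovers `x` up to floating-point error. Because the first half of the input passes through unchanged, the Jacobian is triangular and its log-determinant is simply the sum of the log-scales, which is what keeps the exact likelihood cheap to compute.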

Setup

Dependencies

pip install jax jaxlib
pip install flax

Sample from the model

Random samples can be generated as follows. For instance, to generate 16 samples with sampling temperature 0.7 and random seed 0:

python3 sample.py 16 -t 0.7 -s 0 --model_path [path]
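Sampling at temperature T amounts to scaling the standard normal latent noise by T before running the flow in reverse; lower temperatures trade diversity for sample quality. The sketch below assumes that convention. `decode` and the latent shape are hypothetical placeholders, not the internals or arguments of `sample.py`. In jax, the noise would typically come from `jax.random.normal(jax.random.PRNGKey(seed), shape)` instead of NumPy's generator.

```python
import numpy as np

def sample_latents(num_samples, latent_shape, temperature=0.7, seed=0):
    """Draw z ~ N(0, T^2 I): standard normal noise scaled by the temperature T."""
    rng = np.random.default_rng(seed)  # fixed seed -> reproducible samples
    eps = rng.standard_normal((num_samples, *latent_shape))
    return temperature * eps

def sample(decode, num_samples, latent_shape, temperature=0.7, seed=0):
    """Hypothetical sampling loop; `decode` stands for the inverse flow z -> x."""
    z = sample_latents(num_samples, latent_shape, temperature, seed)
    return decode(z)
```

In a multi-scale model, the latents factored out at each level would each be drawn this way (or, if they have a learned conditional prior, the temperature would scale that prior's standard deviation).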

Example

A pretrained model can be found in the outputs of the Kaggle notebook.

Note: Due to computational limits, the model was only trained for roughly 13 epochs. Compared to the original Glow model, it also uses a shallower flow (K = 16 flow steps per scale).
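The "per scale" refers to Glow's multi-scale architecture: at each level, a squeeze operation trades spatial resolution for channels, K flow steps are applied, and half of the channels are split off as latents before the next level. Below is a minimal NumPy sketch of the squeeze and of the resulting shapes. The 64x64x3 input and the three levels in the usage note are hypothetical, since this README does not state the image size or number of scales used.

```python
import numpy as np

def squeeze2x2(x):
    """(B, H, W, C) -> (B, H/2, W/2, 4C): each 2x2 spatial patch becomes channels."""
    b, h, w, c = x.shape
    x = x.reshape(b, h // 2, 2, w // 2, 2, c)
    x = x.transpose(0, 1, 3, 2, 4, 5)
    return x.reshape(b, h // 2, w // 2, 4 * c)

def unsqueeze2x2(x):
    """Exact inverse of squeeze2x2: (B, H, W, C) -> (B, 2H, 2W, C/4)."""
    b, h, w, c = x.shape
    x = x.reshape(b, h, w, 2, 2, c // 4)
    x = x.transpose(0, 1, 3, 2, 4, 5)
    return x.reshape(b, 2 * h, 2 * w, c // 4)

def level_shapes(image_shape, num_levels):
    """Shape (H, W, C) seen by the K flow steps at each level of a multi-scale flow."""
    h, w, c = image_shape
    shapes = []
    for level in range(num_levels):
        h, w, c = h // 2, w // 2, 4 * c  # squeeze
        shapes.append((h, w, c))
        if level < num_levels - 1:
            c //= 2  # split: half the channels are factored out as latents
    return shapes
```

For example, `level_shapes((64, 64, 3), 3)` gives `[(32, 32, 12), (16, 16, 24), (8, 8, 48)]`. With K = 16 that would be 48 flow steps in total, and the factored-out latents (32*32*6 + 16*16*12 + 8*8*48 = 12288 values) have exactly as many dimensions as the 64*64*3 image, as an invertible mapping requires.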

Example results - training evolution

[Figure: samples at temperature t=0.85 over the course of training]

Example results - sampling

[Figure: random samples at temperature t=0.85]

[Figure: random samples at temperature t=0.7]

Example results - linear interpolation

[Figure: linear interpolation results]
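Because the flow is invertible, interpolation is typically done in latent space: take two latent codes, interpolate linearly between them, and decode each intermediate point with the inverse flow. The README does not say how the endpoints were chosen, so this sketch assumes they are the latents of two encoded images. `encode` and `decode` are hypothetical stand-ins for the model's forward and inverse passes; for a multi-scale model, the same interpolation would be applied to each level's latents.

```python
import numpy as np

def interpolate_latents(z_start, z_end, num_steps=8):
    """Return z(a) = (1 - a) * z_start + a * z_end for num_steps values of a in [0, 1]."""
    alphas = np.linspace(0.0, 1.0, num_steps)
    alphas = alphas.reshape(-1, *([1] * z_start.ndim))  # broadcast over latent dims
    return (1.0 - alphas) * z_start[None] + alphas * z_end[None]

def interpolate_images(encode, decode, x_start, x_end, num_steps=8):
    """Hypothetical pipeline: encode both endpoints, interpolate, decode the batch."""
    z_path = interpolate_latents(encode(x_start), encode(x_end), num_steps)
    return decode(z_path)
```

The first and last rows of the path reproduce the two endpoint latents exactly, so the first and last decoded images are (up to numerical error) the reconstructions of the two inputs.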
