INN.Nonlinear
CLASS Nonlinear(dim, method='NICE', **kwargs)
A nonlinear INN layer for one-dimensional vector transformations.
Common parameters
- `dim`: dimension of the input (and output) vector
- `method`: coupling method; one of `'NICE'`, `'RealNVP'`, or `'ResFlow'` (a construction sketch follows this list)
- `activation_fn`: activation function for the coupling network. If the coupling function itself is given, this argument is ignored.
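A minimal construction sketch showing the three `method` options; all other keyword arguments are left at their defaults:

```python
import INN

# The same layer type with each supported coupling method;
# method-specific options are passed through **kwargs.
nice_layer = INN.Nonlinear(dim=4, method='NICE')
realnvp_layer = INN.Nonlinear(dim=4, method='RealNVP')
resflow_layer = INN.Nonlinear(dim=4, method='ResFlow')
```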
RealNVP and NICE methods
- `k`: number of hidden layers for the coupling models
- `mask`: mask for splitting the input vector
NICE
- `m`: addition function; a neural network mapping vectors of dimension `dim // 2` to dimension `dim - dim // 2`. If `m=None`, it is generated automatically by `INN.utils.default_net(dim, k, activation_fn)` (a custom-`m` sketch follows this list).
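As an illustration of the dimension requirement on `m`, here is a hedged sketch of a custom addition network for `dim = 5`, so the input size is `dim // 2 = 2` and the output size is `dim - dim // 2 = 3`. That any `torch.nn.Module` with matching sizes is acceptable is an assumption based on the `default_net` fallback described above:

```python
import torch.nn as nn
import INN

dim = 5
# Assumption: m may be any module mapping dim // 2 inputs
# to dim - dim // 2 outputs (here: 2 -> 3).
m = nn.Sequential(
    nn.Linear(dim // 2, 32),
    nn.ReLU(),
    nn.Linear(32, dim - dim // 2),
)
layer = INN.Nonlinear(dim=dim, method='NICE', m=m)
```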
RealNVP
- `f_log_s`: multiplication function; it has the same dimension requirements as `m`
- `f_t`: addition function; it has the same dimension requirements as `m`
- `clip` (default `clip=1`): clips the output of `f_log_s` to `[-clip, clip]` to avoid extreme values. The clipping uses `tanh` to keep the gradient (illustrated in the sketch after this list).
- `scale`: scale for initializing the weights of the coupling function. A large `scale` may lead to `NaN` results because the scaling factor is exponentiated.
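The `tanh`-based clipping can be pictured with the following standalone sketch; this illustrates the idea rather than reproducing the library's exact code:

```python
import torch

def soft_clip(raw_log_s, clip=1.0):
    # tanh saturates smoothly instead of cutting off hard,
    # so gradients keep flowing near the clipping boundary.
    return clip * torch.tanh(raw_log_s / clip)

raw = torch.tensor([-5.0, -0.5, 0.0, 0.5, 5.0])
print(soft_clip(raw))  # all values stay inside (-1, 1) for clip=1
```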
ResFlow
- `hidden`: dimension of the hidden layers
- `n_hidden`: number of hidden layers
- `lipschitz_constrain`: Lipschitz constraint; it should be lower than 1. A low value may reduce the expressive power of the network.
- `mem_efficient`: use memory-efficient back-propagation if `True`
- `est_steps`: number of iterations for estimating gradients and Jacobians (a construction sketch using these keywords follows this list)
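A hedged construction sketch using the ResFlow keywords above; the numeric values are illustrative assumptions, not documented defaults:

```python
import INN

layer = INN.Nonlinear(
    dim=4,
    method='ResFlow',
    hidden=64,                # width of the hidden layers (illustrative)
    n_hidden=2,               # number of hidden layers (illustrative)
    lipschitz_constrain=0.9,  # should stay below 1
    mem_efficient=True,       # memory-efficient back-propagation
    est_steps=10,             # iterations for gradient/Jacobian estimation
)
```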
forward(x, compute_p=True)
Compute the transformed output `y`. If `compute_p=True`, return `y`, `log_p0`, `logdet`.
The `logdet` term is the log-determinant of the Jacobian matrix. This is essential for controlling the distribution of the output (see the loss sketch below).
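Because `logdet` is the change-of-variables correction, a common training objective is the negative log-likelihood `-(log_p0 + logdet)`. The sketch below assumes that standard objective; it is not code from this wiki:

```python
import torch
import INN

model = INN.Nonlinear(dim=4, method='RealNVP')
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

x = torch.randn(32, 4)             # a batch of training vectors
y, log_p0, logdet = model(x)       # forward pass returning densities
loss = -(log_p0 + logdet).mean()   # standard change-of-variables NLL
loss.backward()
optimizer.step()
```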
inverse(y, **args)
Compute the inverse of `y`. The `**args` here is a placeholder to keep a consistent call format across layers.
Example:

```python
import torch
import INN

model = INN.Nonlinear(dim=4, method='RealNVP')
x = torch.Tensor([[1, 2, 3, 4], [5, 6, 7, 8]])

# Forward pass
y, log_p, log_det = model(x)
# Inverse pass
x_recon = model.inverse(y)
```
The outputs are:

```python
# y
tensor([[0.9963, 2.0015, 3.0014, 3.9964],
        [4.9938, 6.0042, 7.0035, 7.9923]], grad_fn=<AddBackward0>)

# x_recon
tensor([[1.0000, 2.0000, 3.0000, 4.0000],
        [5.0000, 6.0000, 7.0000, 8.0000]], grad_fn=<AddBackward0>)
```
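A quick follow-up check of invertibility, reusing `x` and `x_recon` from the example; the tolerance is a loose assumption given the small forward-pass deviations seen above:

```python
print(torch.allclose(x, x_recon, atol=1e-4))  # expected: True
```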