Support KandinskyV3
okotaku committed Dec 26, 2023
1 parent ecc93f0 commit 8b8d7e7
Showing 8 changed files with 278 additions and 21 deletions.
1 change: 1 addition & 0 deletions README.md
@@ -276,6 +276,7 @@ For detailed user guides and advanced guides, please refer to our [Documentation
<td>
<ul>
<li><a href="configs/kandinsky_v22/README.md">Kandinsky 2.2 (2023)</a></li>
<li><a href="configs/kandinsky_v3/README.md">Kandinsky 3 (2023)</a></li>
</ul>
</td>
</tr>
2 changes: 1 addition & 1 deletion configs/_base_/datasets/pokemon_blip_kandinsky_v3.py
@@ -7,7 +7,7 @@
dict(type="PackInputs"),
]
train_dataloader = dict(
batch_size=1,
batch_size=2,
num_workers=4,
dataset=dict(
type="HFDataset",
3 changes: 1 addition & 2 deletions configs/_base_/models/kandinsky_v3.py
@@ -1,4 +1,3 @@
model = dict(
    type="KandinskyV3",
-    model="kandinsky-community/kandinsky-3",
-    gradient_checkpointing=True)
+    model="kandinsky-community/kandinsky-3")
11 changes: 5 additions & 6 deletions configs/kandinsky_v3/README.md
@@ -7,7 +7,7 @@
We present Kandinsky 3.0, a large-scale text-to-image generation model based on latent diffusion, continuing the series of text-to-image Kandinsky models and reflecting our progress toward higher quality and realism in image generation. Compared to previous versions of Kandinsky 2.x, Kandinsky 3.0 leverages a two times larger UNet backbone and a ten times larger text encoder, and removes diffusion mapping. We describe the architecture of the model, the data collection procedure, the training technique, and the production system for user interaction. We focus on the key components that, as we identified through a large number of experiments, had the most significant impact on improving the quality of our model in comparison with the others. Our side-by-side comparisons show that Kandinsky becomes better at text understanding and works better on specific domains.

<div align=center>
<img src=""/>
<img src="https://github.com/okotaku/diffengine/assets/24734142/2d670f44-9fa1-4095-be96-a82c91c9590b"/>
</div>

## Citation
@@ -46,24 +46,23 @@ Before running inference, we should convert the weights to diffusers format,
```bash
$ mim run diffengine publish_model2diffusers ${CONFIG_FILE} ${INPUT_FILENAME} ${OUTPUT_DIR} --save-keys ${SAVE_KEYS}
# Example
-$ mim run diffengine publish_model2diffusers configs/kandinsky_v3/kandinsky_v3_pokemon_blip.py work_dirs/kandinsky_v3_pokemon_blip/epoch_50.pth work_dirs/kandinsky_v3_pokemon_blip --save-keys unet
+# Note that when training with colossalai, use `--colossalai` and set `INPUT_FILENAME` to the index file.
+$ mim run diffengine publish_model2diffusers configs/kandinsky_v3/kandinsky_v3_pokemon_blip.py work_dirs/kandinsky_v3_pokemon_blip/epoch_50.pth/model/pytorch_model.bin.index.json work_dirs/kandinsky_v3_pokemon_blip --save-keys unet --colossalai
```

Then we can run inference.

```py
import torch
from diffusers import AutoPipelineForText2Image, Kandinsky3UNet

prompt = 'yoda pokemon'
checkpoint = 'work_dirs/kandinsky_v3_pokemon_blip'

unet = Kandinsky3UNet.from_pretrained(
-    checkpoint, subfolder='unet', torch_dtype=torch.float16)
+    checkpoint, subfolder='unet')
pipe = AutoPipelineForText2Image.from_pretrained(
    "kandinsky-community/kandinsky-3",
    unet=unet,
-    torch_dtype=torch.float16,
-    variant="fp16",
)
pipe.to('cuda')
@@ -81,4 +80,4 @@ image.save('demo.png')

#### kandinsky_v3_pokemon_blip

-![example1](<>)
+![example1](https://github.com/okotaku/diffengine/assets/24734142/8f078fa8-9485-40d9-8174-5996257aed88)
18 changes: 18 additions & 0 deletions configs/kandinsky_v3/kandinsky_v3_pokemon_blip.py
@@ -4,3 +4,21 @@
"../_base_/schedules/stable_diffusion_xl_50e.py",
"../_base_/default_runtime.py",
]

optim_wrapper = dict(
_delete_=True,
optimizer=dict(
type="HybridAdam",
lr=1e-5,
weight_decay=1e-2),
accumulative_counts=4)

default_hooks = dict(
checkpoint=dict(save_param_scheduler=False)) # no scheduler in this config

runner_type = "FlexibleRunner"
strategy = dict(type="ColossalAIStrategy",
plugin=dict(type="LowLevelZeroPlugin",
stage=2,
precision="bf16",
max_norm=1.0))
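For context, `accumulative_counts=4` makes mmengine's `OptimWrapper` accumulate gradients and step the optimizer only on every fourth `update_params` call, so combined with the dataset's `batch_size=2` the effective per-device batch size is 8. A minimal standalone sketch of that behavior (the toy model and loop are assumptions for illustration, not part of this commit):

```py
# Sketch of gradient accumulation in mmengine's OptimWrapper; the toy
# linear model and random data are assumptions for illustration only.
import torch
from mmengine.optim import OptimWrapper
from torch import nn
from torch.optim import SGD

model = nn.Linear(4, 1)  # toy stand-in for the UNet
optim_wrapper = OptimWrapper(SGD(model.parameters(), lr=0.1),
                             accumulative_counts=4)

for step in range(8):
    loss = model(torch.randn(2, 4)).mean()  # batch_size=2, as in the dataset config
    optim_wrapper.update_params(loss)  # optimizer steps after iterations 4 and 8
```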
13 changes: 5 additions & 8 deletions diffengine/models/editors/kandinsky/kandinskyv3.py
@@ -74,6 +74,8 @@ def __init__(
        gradient_checkpointing: bool = False,
        enable_xformers: bool = False,
    ) -> None:
+        assert gradient_checkpointing is False, (
+            "KandinskyV3 does not support gradient checkpointing.")
        if data_preprocessor is None:
            data_preprocessor = {"type": "SDDataPreprocessor"}
        if noise_generator is None:
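The new assertion turns the unsupported option into an immediate failure at construction time, which `test_init` in the added test file exercises. Roughly (a sketch mirroring that test; the assertion fires before any weights load):

```py
# Sketch mirroring test_init in the new test file: constructing the
# editor with gradient_checkpointing=True now raises immediately.
import pytest

from diffengine.models.editors import KandinskyV3

with pytest.raises(AssertionError,
                   match="KandinskyV3 does not support gradient"):
    KandinskyV3(model="kandinsky-community/kandinsky-3",
                gradient_checkpointing=True)
```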
@@ -129,13 +131,10 @@ def prepare_model(self) -> None:
        Disable gradient for some models.
        """
-        if self.gradient_checkpointing:
-            self.unet.enable_gradient_checkpointing()
-
        self.vae.requires_grad_(requires_grad=False)
        print_log("Set VAE untrainable.", "current")
        self.image_encoder.requires_grad_(requires_grad=False)
        print_log("Set Image Encoder untrainable.", "current")
        self.text_encoder.requires_grad_(requires_grad=False)
        print_log("Set Text Encoder untrainable.", "current")

    def set_xformers(self) -> None:
        """Set xformers for model."""
@@ -192,7 +191,7 @@ def infer(self,
        if width is None:
            width = 1024
        pipeline = AutoPipelineForText2Image.from_pretrained(
-            self.decoder_model,
+            self.model,
            movq=self.vae,
            tokenizer=self.tokenizer,
            text_encoder=self.text_encoder,
@@ -340,8 +339,6 @@ def forward(

        encoder_hidden_states = self.text_encoder(
            inputs["text"], attention_mask=inputs["attention_mask"])[0]
-        # encoder_hidden_states = encoder_hidden_states *
-        # inputs["attention_mask"].unsqueeze(2)

        model_pred = self.unet(
            noisy_latents,
230 changes: 230 additions & 0 deletions tests/test_models/test_editors/test_kandinsky/test_kandinsky_v3.py
@@ -0,0 +1,230 @@
from copy import deepcopy
from unittest import TestCase

import pytest
import torch
from diffusers import DDPMScheduler, Kandinsky3UNet, VQModel
from mmengine.optim import OptimWrapper
from torch import nn
from torch.optim import SGD
from transformers import AutoTokenizer, T5EncoderModel

from diffengine.models.editors import (
    KandinskyV3,
    SDDataPreprocessor,
)
from diffengine.models.losses import DeBiasEstimationLoss, L2Loss, SNRL2Loss
from diffengine.registry import MODELS


class DummyKandinskyV3(KandinskyV3):
    def __init__(
        self,
        model: str = "kandinsky-community/kandinsky-3",
        loss: dict | None = None,
        unet_lora_config: dict | None = None,
        prior_loss_weight: float = 1.,
        tokenizer_max_length: int = 128,
        prediction_type: str | None = None,
        data_preprocessor: dict | nn.Module | None = None,
        noise_generator: dict | None = None,
        timesteps_generator: dict | None = None,
        input_perturbation_gamma: float = 0.0,
        *,
        gradient_checkpointing: bool = False,
    ) -> None:
        assert gradient_checkpointing is False, (
            "KandinskyV3 does not support gradient checkpointing.")
        if data_preprocessor is None:
            data_preprocessor = {"type": "SDDataPreprocessor"}
        if noise_generator is None:
            noise_generator = {"type": "WhiteNoise"}
        if timesteps_generator is None:
            timesteps_generator = {"type": "TimeSteps"}
        if loss is None:
            loss = {"type": "L2Loss", "loss_weight": 1.0}
        super(KandinskyV3, self).__init__(data_preprocessor=data_preprocessor)

        self.model = model
        self.unet_lora_config = deepcopy(unet_lora_config)
        self.prior_loss_weight = prior_loss_weight
        self.gradient_checkpointing = gradient_checkpointing
        self.tokenizer_max_length = tokenizer_max_length
        self.input_perturbation_gamma = input_perturbation_gamma

        if not isinstance(loss, nn.Module):
            loss = MODELS.build(loss)
        self.loss_module: nn.Module = loss

        self.prediction_type = prediction_type

        self.tokenizer = AutoTokenizer.from_pretrained(
            "hf-internal-testing/tiny-random-t5")

        self.text_encoder = T5EncoderModel.from_pretrained(
            "hf-internal-testing/tiny-random-t5")

        self.scheduler = DDPMScheduler()

        vae_kwargs = {
            "block_out_channels": [32, 64],
            "down_block_types": ["DownEncoderBlock2D",
                                 "AttnDownEncoderBlock2D"],
            "in_channels": 3,
            "latent_channels": 4,
            "layers_per_block": 1,
            "norm_num_groups": 8,
            "norm_type": "spatial",
            "num_vq_embeddings": 12,
            "out_channels": 3,
            "up_block_types": [
                "AttnUpDecoderBlock2D",
                "UpDecoderBlock2D",
            ],
            "vq_embed_dim": 4,
        }
        self.vae = VQModel(**vae_kwargs)

        self.unet = Kandinsky3UNet(
            in_channels=4,
            time_embedding_dim=4,
            groups=2,
            attention_head_dim=4,
            layers_per_block=3,
            block_out_channels=(32, 64),
            cross_attention_dim=4,
            encoder_hid_dim=32,
        )
        self.noise_generator = MODELS.build(noise_generator)
        self.timesteps_generator = MODELS.build(timesteps_generator)

        self.prepare_model()
        self.set_lora()


class TestKandinskyV3(TestCase):

    def test_init(self):
        with pytest.raises(
                AssertionError, match="KandinskyV3 does not support gradient"):
            _ = DummyKandinskyV3(
                loss=L2Loss(),
                data_preprocessor=SDDataPreprocessor(),
                gradient_checkpointing=True)

    def test_train_step(self):
        # test load with loss module
        StableDiffuser = DummyKandinskyV3(
            loss=L2Loss(),
            data_preprocessor=SDDataPreprocessor())

        # test train step
        data = dict(
            inputs=dict(img=[torch.zeros((3, 64, 64))], text=["a dog"],
                        clip_img=[torch.zeros((3, 224, 224))]))
        optimizer = SGD(StableDiffuser.parameters(), lr=0.1)
        optim_wrapper = OptimWrapper(optimizer)
        log_vars = StableDiffuser.train_step(data, optim_wrapper)
        assert log_vars
        assert isinstance(log_vars["loss"], torch.Tensor)

    def test_train_step_with_lora(self):
        # test load with loss module
        StableDiffuser = DummyKandinskyV3(
            loss=L2Loss(),
            unet_lora_config=dict(
                type="LoRA", r=4,
                target_modules=["to_q", "to_v", "to_k", "to_out.0"]),
            data_preprocessor=SDDataPreprocessor())

        # test train step
        data = dict(
            inputs=dict(img=[torch.zeros((3, 64, 64))], text=["a dog"],
                        clip_img=[torch.zeros((3, 224, 224))]))
        optimizer = SGD(StableDiffuser.parameters(), lr=0.1)
        optim_wrapper = OptimWrapper(optimizer)
        log_vars = StableDiffuser.train_step(data, optim_wrapper)
        assert log_vars
        assert isinstance(log_vars["loss"], torch.Tensor)

    def test_train_step_input_perturbation(self):
        # test load with loss module
        StableDiffuser = DummyKandinskyV3(
            input_perturbation_gamma=0.1,
            loss=L2Loss(),
            data_preprocessor=SDDataPreprocessor())

        # test train step
        data = dict(
            inputs=dict(img=[torch.zeros((3, 64, 64))], text=["a dog"],
                        clip_img=[torch.zeros((3, 224, 224))]))
        optimizer = SGD(StableDiffuser.parameters(), lr=0.1)
        optim_wrapper = OptimWrapper(optimizer)
        log_vars = StableDiffuser.train_step(data, optim_wrapper)
        assert log_vars
        assert isinstance(log_vars["loss"], torch.Tensor)

    def test_train_step_dreambooth(self):
        # test load with loss module
        StableDiffuser = DummyKandinskyV3(
            loss=L2Loss(),
            data_preprocessor=SDDataPreprocessor())

        # test train step
        data = dict(
            inputs=dict(img=[torch.zeros((3, 64, 64))], text=["a sks dog"],
                        clip_img=[torch.zeros((3, 224, 224))]))
        data["inputs"]["result_class_image"] = dict(
            img=[torch.zeros((3, 64, 64))],
            text=["a dog"],
            clip_img=[torch.zeros((3, 224, 224))])  # type: ignore[assignment]
        optimizer = SGD(StableDiffuser.parameters(), lr=0.1)
        optim_wrapper = OptimWrapper(optimizer)
        log_vars = StableDiffuser.train_step(data, optim_wrapper)
        assert log_vars
        assert isinstance(log_vars["loss"], torch.Tensor)

    def test_train_step_snr_loss(self):
        # test load with loss module
        StableDiffuser = DummyKandinskyV3(
            loss=SNRL2Loss(),
            data_preprocessor=SDDataPreprocessor())

        # test train step
        data = dict(
            inputs=dict(img=[torch.zeros((3, 64, 64))], text=["a dog"],
                        clip_img=[torch.zeros((3, 224, 224))]))
        optimizer = SGD(StableDiffuser.parameters(), lr=0.1)
        optim_wrapper = OptimWrapper(optimizer)
        log_vars = StableDiffuser.train_step(data, optim_wrapper)
        assert log_vars
        assert isinstance(log_vars["loss"], torch.Tensor)

    def test_train_step_debias_estimation_loss(self):
        # test load with loss module
        StableDiffuser = DummyKandinskyV3(
            loss=DeBiasEstimationLoss(),
            data_preprocessor=SDDataPreprocessor())

        # test train step
        data = dict(
            inputs=dict(img=[torch.zeros((3, 64, 64))], text=["a dog"],
                        clip_img=[torch.zeros((3, 224, 224))]))
        optimizer = SGD(StableDiffuser.parameters(), lr=0.1)
        optim_wrapper = OptimWrapper(optimizer)
        log_vars = StableDiffuser.train_step(data, optim_wrapper)
        assert log_vars
        assert isinstance(log_vars["loss"], torch.Tensor)

    def test_val_and_test_step(self):
        StableDiffuser = DummyKandinskyV3(
            loss=L2Loss(),
            data_preprocessor=SDDataPreprocessor())

        # test val_step
        with pytest.raises(NotImplementedError, match="val_step is not"):
            StableDiffuser.val_step(torch.zeros((1, )))

        # test test_step
        with pytest.raises(NotImplementedError, match="test_step is not"):
            StableDiffuser.test_step(torch.zeros((1, )))