From 8b8d7e74d67cae8b5457ad71297d0953e3bbfe22 Mon Sep 17 00:00:00 2001
From: okotaku
Date: Tue, 26 Dec 2023 01:39:43 +0000
Subject: [PATCH] Support KandinskyV3

---
 README.md                                     |   1 +
 .../datasets/pokemon_blip_kandinsky_v3.py     |   2 +-
 configs/_base_/models/kandinsky_v3.py         |   3 +-
 configs/kandinsky_v3/README.md                |  11 +-
 .../kandinsky_v3/kandinsky_v3_pokemon_blip.py |  18 ++
 .../models/editors/kandinsky/kandinskyv3.py   |  13 +-
 .../test_kandinsky/test_kandinsky_v3.py       | 230 ++++++++++++++++++
 .../publish_model2diffusers.py                |  21 +-
 8 files changed, 278 insertions(+), 21 deletions(-)
 create mode 100644 tests/test_models/test_editors/test_kandinsky/test_kandinsky_v3.py

diff --git a/README.md b/README.md
index ecfb7ff..239051c 100644
--- a/README.md
+++ b/README.md
@@ -276,6 +276,7 @@ For detailed user guides and advanced guides, please refer to our [Documentation
diff --git a/configs/_base_/datasets/pokemon_blip_kandinsky_v3.py b/configs/_base_/datasets/pokemon_blip_kandinsky_v3.py
index c8c012e..6f623b1 100644
--- a/configs/_base_/datasets/pokemon_blip_kandinsky_v3.py
+++ b/configs/_base_/datasets/pokemon_blip_kandinsky_v3.py
@@ -7,7 +7,7 @@
     dict(type="PackInputs"),
 ]
 train_dataloader = dict(
-    batch_size=1,
+    batch_size=2,
     num_workers=4,
     dataset=dict(
         type="HFDataset",
diff --git a/configs/_base_/models/kandinsky_v3.py b/configs/_base_/models/kandinsky_v3.py
index a81908d..6ac90bf 100644
--- a/configs/_base_/models/kandinsky_v3.py
+++ b/configs/_base_/models/kandinsky_v3.py
@@ -1,4 +1,3 @@
 model = dict(
     type="KandinskyV3",
-    model="kandinsky-community/kandinsky-3",
-    gradient_checkpointing=True)
+    model="kandinsky-community/kandinsky-3")
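The base model config above is consumed through mmengine's `_base_` inheritance: the training config added later in this patch lists it under `_base_` and overrides the optimizer and strategy. A minimal sketch of how the merged config resolves, assuming mmengine's `Config.fromfile` and that it is run from the repository root:

```py
from mmengine.config import Config

# Sketch: `fromfile` merges every file listed in `_base_`, so `cfg.model`
# comes from configs/_base_/models/kandinsky_v3.py while `cfg.optim_wrapper`
# comes from the training config's override added later in this patch.
cfg = Config.fromfile("configs/kandinsky_v3/kandinsky_v3_pokemon_blip.py")
print(cfg.model)  # {'type': 'KandinskyV3', 'model': 'kandinsky-community/kandinsky-3'}
print(cfg.optim_wrapper.optimizer.type)  # 'HybridAdam'
```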
diff --git a/configs/kandinsky_v3/README.md b/configs/kandinsky_v3/README.md
index 6264580..9239b2d 100644
--- a/configs/kandinsky_v3/README.md
+++ b/configs/kandinsky_v3/README.md
@@ -7,7 +7,7 @@

 We present Kandinsky 3.0, a large-scale text-to-image generation model based on latent diffusion, continuing the series of text-to-image Kandinsky models and reflecting our progress toward higher quality and realism in image generation. Compared to previous versions of Kandinsky 2.x, Kandinsky 3.0 leverages a two times larger UNet backbone and a ten times larger text encoder, and removes the diffusion mapping. We describe the architecture of the model, the data collection procedure, the training technique, and the production system of user interaction. We focus on the key components that, as we identified through a large number of experiments, had the most significant impact on improving the quality of our model in comparison with the other ones. In our side-by-side comparisons, Kandinsky 3.0 performs better in text understanding and works better on specific domains.

-
+

 ## Citation

@@ -46,24 +46,23 @@
 Before running inference, we need to convert the trained weights to the diffusers format:

 ```bash
 $ mim run diffengine publish_model2diffusers ${CONFIG_FILE} ${INPUT_FILENAME} ${OUTPUT_DIR} --save-keys ${SAVE_KEYS}
 # Example
-$ mim run diffengine publish_model2diffusers configs/kandinsky_v3/kandinsky_v3_pokemon_blip.py work_dirs/kandinsky_v3_pokemon_blip/epoch_50.pth work_dirs/kandinsky_v3_pokemon_blip --save-keys unet
+# Note: when the model was trained with ColossalAI, add `--colossalai` and set `INPUT_FILENAME` to the index file.
+$ mim run diffengine publish_model2diffusers configs/kandinsky_v3/kandinsky_v3_pokemon_blip.py work_dirs/kandinsky_v3_pokemon_blip/epoch_50.pth/model/pytorch_model.bin.index.json work_dirs/kandinsky_v3_pokemon_blip --save-keys unet --colossalai
 ```

 Then we can run inference.

 ```py
-import torch
 from diffusers import AutoPipelineForText2Image, Kandinsky3UNet

 prompt = 'yoda pokemon'
 checkpoint = 'work_dirs/kandinsky_v3_pokemon_blip'

 unet = Kandinsky3UNet.from_pretrained(
-    checkpoint, subfolder='unet', torch_dtype=torch.float16)
+    checkpoint, subfolder='unet')
 pipe = AutoPipelineForText2Image.from_pretrained(
     "kandinsky-community/kandinsky-3",
     unet=unet,
-    torch_dtype=torch.float16,
     variant="fp16",
 )
 pipe.to('cuda')
@@ -81,4 +80,4 @@ image.save('demo.png')

 #### kandinsky_v3_pokemon_blip

-![example1](<>)
+![example1](https://github.com/okotaku/diffengine/assets/24734142/8f078fa8-9485-40d9-8174-5996257aed88)
diff --git a/configs/kandinsky_v3/kandinsky_v3_pokemon_blip.py b/configs/kandinsky_v3/kandinsky_v3_pokemon_blip.py
index 3acc513..2d8e5e1 100644
--- a/configs/kandinsky_v3/kandinsky_v3_pokemon_blip.py
+++ b/configs/kandinsky_v3/kandinsky_v3_pokemon_blip.py
@@ -4,3 +4,21 @@
     "../_base_/schedules/stable_diffusion_xl_50e.py",
     "../_base_/default_runtime.py",
 ]
+
+optim_wrapper = dict(
+    _delete_=True,
+    optimizer=dict(
+        type="HybridAdam",
+        lr=1e-5,
+        weight_decay=1e-2),
+    accumulative_counts=4)
+
+default_hooks = dict(
+    checkpoint=dict(save_param_scheduler=False))  # no scheduler in this config
+
+runner_type = "FlexibleRunner"
+strategy = dict(type="ColossalAIStrategy",
+                plugin=dict(type="LowLevelZeroPlugin",
+                            stage=2,
+                            precision="bf16",
+                            max_norm=1.0))
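With the ColossalAI strategy above, checkpoints are saved as sharded directories rather than a single `.pth` file, which is why the README points `INPUT_FILENAME` at `pytorch_model.bin.index.json`. A minimal sketch for inspecting such a checkpoint, assuming ColossalAI writes a HuggingFace-style index whose `weight_map` maps parameter names to shard files (the path is the README's example):

```py
import json
from pathlib import Path

# Assumed layout: epoch_50.pth/ is a directory containing model/ with
# sharded *.bin files plus a JSON index mapping each tensor to its shard.
index_file = Path("work_dirs/kandinsky_v3_pokemon_blip/"
                  "epoch_50.pth/model/pytorch_model.bin.index.json")
index = json.loads(index_file.read_text())
shards = sorted(set(index["weight_map"].values()))
print(f"{len(index['weight_map'])} tensors across {len(shards)} shards")
```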
""" - if self.gradient_checkpointing: - self.unet.enable_gradient_checkpointing() - self.vae.requires_grad_(requires_grad=False) print_log("Set VAE untrainable.", "current") - self.image_encoder.requires_grad_(requires_grad=False) - print_log("Set Image Encoder untrainable.", "current") + self.text_encoder.requires_grad_(requires_grad=False) + print_log("Set Text Encoder untrainable.", "current") def set_xformers(self) -> None: """Set xformers for model.""" @@ -192,7 +191,7 @@ def infer(self, if width is None: width = 1024 pipeline = AutoPipelineForText2Image.from_pretrained( - self.decoder_model, + self.model, movq=self.vae, tokenizer=self.tokenizer, text_encoder=self.text_encoder, @@ -340,8 +339,6 @@ def forward( encoder_hidden_states = self.text_encoder( inputs["text"], attention_mask=inputs["attention_mask"])[0] - # encoder_hidden_states = encoder_hidden_states * - # inputs["attention_mask"].unsqueeze(2) model_pred = self.unet( noisy_latents, diff --git a/tests/test_models/test_editors/test_kandinsky/test_kandinsky_v3.py b/tests/test_models/test_editors/test_kandinsky/test_kandinsky_v3.py new file mode 100644 index 0000000..f0d2304 --- /dev/null +++ b/tests/test_models/test_editors/test_kandinsky/test_kandinsky_v3.py @@ -0,0 +1,230 @@ +from copy import deepcopy +from unittest import TestCase + +import pytest +import torch +from diffusers import DDPMScheduler, Kandinsky3UNet, VQModel +from mmengine.optim import OptimWrapper +from torch import nn +from torch.optim import SGD +from transformers import AutoTokenizer, T5EncoderModel + +from diffengine.models.editors import ( + KandinskyV3, + SDDataPreprocessor, +) +from diffengine.models.losses import DeBiasEstimationLoss, L2Loss, SNRL2Loss +from diffengine.registry import MODELS + + +class DummyKandinskyV3(KandinskyV3): + def __init__( + self, + model: str = "kandinsky-community/kandinsky-3", + loss: dict | None = None, + unet_lora_config: dict | None = None, + prior_loss_weight: float = 1., + tokenizer_max_length: int = 128, + prediction_type: str | None = None, + data_preprocessor: dict | nn.Module | None = None, + noise_generator: dict | None = None, + timesteps_generator: dict | None = None, + input_perturbation_gamma: float = 0.0, + *, + gradient_checkpointing: bool = False, + ) -> None: + assert gradient_checkpointing is False, ( + "KandinskyV3 does not support gradient checkpointing.") + if data_preprocessor is None: + data_preprocessor = {"type": "SDDataPreprocessor"} + if noise_generator is None: + noise_generator = {"type": "WhiteNoise"} + if timesteps_generator is None: + timesteps_generator = {"type": "TimeSteps"} + if loss is None: + loss = {"type": "L2Loss", "loss_weight": 1.0} + super(KandinskyV3, self).__init__(data_preprocessor=data_preprocessor) + + self.model = model + self.unet_lora_config = deepcopy(unet_lora_config) + self.prior_loss_weight = prior_loss_weight + self.gradient_checkpointing = gradient_checkpointing + self.tokenizer_max_length = tokenizer_max_length + self.input_perturbation_gamma = input_perturbation_gamma + + if not isinstance(loss, nn.Module): + loss = MODELS.build(loss) + self.loss_module: nn.Module = loss + + self.prediction_type = prediction_type + + self.tokenizer = AutoTokenizer.from_pretrained( + "hf-internal-testing/tiny-random-t5") + + self.text_encoder = T5EncoderModel.from_pretrained( + "hf-internal-testing/tiny-random-t5") + + self.scheduler = DDPMScheduler() + + vae_kwargs = { + "block_out_channels": [32, 64], + "down_block_types": ["DownEncoderBlock2D", + "AttnDownEncoderBlock2D"], + 
"in_channels": 3, + "latent_channels": 4, + "layers_per_block": 1, + "norm_num_groups": 8, + "norm_type": "spatial", + "num_vq_embeddings": 12, + "out_channels": 3, + "up_block_types": [ + "AttnUpDecoderBlock2D", + "UpDecoderBlock2D", + ], + "vq_embed_dim": 4, + } + self.vae = VQModel(**vae_kwargs) + + self.unet = Kandinsky3UNet( + in_channels=4, + time_embedding_dim=4, + groups=2, + attention_head_dim=4, + layers_per_block=3, + block_out_channels=(32, 64), + cross_attention_dim=4, + encoder_hid_dim=32, + ) + self.noise_generator = MODELS.build(noise_generator) + self.timesteps_generator = MODELS.build(timesteps_generator) + + self.prepare_model() + self.set_lora() + + +class TestKandinskyV3(TestCase): + + def test_init(self): + with pytest.raises( + AssertionError, match="KandinskyV3 does not support gradient"): + _ = DummyKandinskyV3( + loss=L2Loss(), + data_preprocessor=SDDataPreprocessor(), + gradient_checkpointing=True) + + def test_train_step(self): + # test load with loss module + StableDiffuser = DummyKandinskyV3( + loss=L2Loss(), + data_preprocessor=SDDataPreprocessor()) + + # test train step + data = dict( + inputs=dict(img=[torch.zeros((3, 64, 64))], text=["a dog"], + clip_img=[torch.zeros((3, 224, 224))])) + optimizer = SGD(StableDiffuser.parameters(), lr=0.1) + optim_wrapper = OptimWrapper(optimizer) + log_vars = StableDiffuser.train_step(data, optim_wrapper) + assert log_vars + assert isinstance(log_vars["loss"], torch.Tensor) + + def test_train_step_with_lora(self): + # test load with loss module + StableDiffuser = DummyKandinskyV3( + loss=L2Loss(), + unet_lora_config=dict( + type="LoRA", r=4, + target_modules=["to_q", "to_v", "to_k", "to_out.0"]), + data_preprocessor=SDDataPreprocessor()) + + # test train step + data = dict( + inputs=dict(img=[torch.zeros((3, 64, 64))], text=["a dog"], + clip_img=[torch.zeros((3, 224, 224))])) + optimizer = SGD(StableDiffuser.parameters(), lr=0.1) + optim_wrapper = OptimWrapper(optimizer) + log_vars = StableDiffuser.train_step(data, optim_wrapper) + assert log_vars + assert isinstance(log_vars["loss"], torch.Tensor) + + def test_train_step_input_perturbation(self): + # test load with loss module + StableDiffuser = DummyKandinskyV3( + input_perturbation_gamma=0.1, + loss=L2Loss(), + data_preprocessor=SDDataPreprocessor()) + + # test train step + data = dict( + inputs=dict(img=[torch.zeros((3, 64, 64))], text=["a dog"], + clip_img=[torch.zeros((3, 224, 224))])) + optimizer = SGD(StableDiffuser.parameters(), lr=0.1) + optim_wrapper = OptimWrapper(optimizer) + log_vars = StableDiffuser.train_step(data, optim_wrapper) + assert log_vars + assert isinstance(log_vars["loss"], torch.Tensor) + + def test_train_step_dreambooth(self): + # test load with loss module + StableDiffuser = DummyKandinskyV3( + loss=L2Loss(), + data_preprocessor=SDDataPreprocessor()) + + # test train step + data = dict( + inputs=dict(img=[torch.zeros((3, 64, 64))], text=["a sks dog"], + clip_img=[torch.zeros((3, 224, 224))])) + data["inputs"]["result_class_image"] = dict( + img=[torch.zeros((3, 64, 64))], + text=["a dog"], + clip_img=[torch.zeros((3, 224, 224))]) # type: ignore[assignment] + optimizer = SGD(StableDiffuser.parameters(), lr=0.1) + optim_wrapper = OptimWrapper(optimizer) + log_vars = StableDiffuser.train_step(data, optim_wrapper) + assert log_vars + assert isinstance(log_vars["loss"], torch.Tensor) + + def test_train_step_snr_loss(self): + # test load with loss module + StableDiffuser = DummyKandinskyV3( + loss=SNRL2Loss(), + 
diff --git a/tools/model_converters/publish_model2diffusers.py b/tools/model_converters/publish_model2diffusers.py
index da4f5e7..41d6c73 100644
--- a/tools/model_converters/publish_model2diffusers.py
+++ b/tools/model_converters/publish_model2diffusers.py
@@ -1,5 +1,6 @@
 import argparse
 import os.path as osp
+from pathlib import Path

 import torch
 from mmengine.config import Config
@@ -20,6 +21,10 @@ def parse_args():  # noqa
         type=str,
         default=["unet", "text_encoder", "transformer"],
         help="keys to save in the published checkpoint")
+    parser.add_argument(
+        "--colossalai",
+        action="store_true",
+        help="whether the checkpoint was trained with ColossalAI")
     return parser.parse_args()

@@ -44,15 +49,23 @@ def main() -> None:
         cfg.work_dir = osp.join("./work_dirs",
                                 osp.splitext(osp.basename(args.config))[0])

+    if args.colossalai:
+        cfg.strategy = None
+        cfg.pop("runner_type")
+
     # build the runner from config
     runner = (
         Runner.from_cfg(cfg)
         if "runner_type" not in cfg
         else RUNNERS.build(cfg))

-    state_dict = torch.load(args.in_file)
-    if "state_dict" in state_dict:
-        state_dict = state_dict["state_dict"]
-    runner.model.load_state_dict(state_dict, strict=False)
+    if args.colossalai:
+        from colossalai.checkpoint_io import GeneralCheckpointIO
+        GeneralCheckpointIO().load_sharded_model(runner.model, Path(args.in_file))
+    else:
+        state_dict = torch.load(args.in_file)
+        if "state_dict" in state_dict:
+            state_dict = state_dict["state_dict"]
+        runner.model.load_state_dict(state_dict, strict=False)

     process_checkpoint(runner, args.out_dir, args.save_keys)
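`process_checkpoint` is imported from elsewhere in diffengine and is not part of this patch. A hypothetical sketch of its shape, assuming each entry in `--save-keys` names a submodule exposing the diffusers/transformers `save_pretrained` API:

```py
import os.path as osp


def process_checkpoint(runner, out_dir: str, save_keys: list[str]) -> None:
    # Hypothetical sketch, not the real diffengine helper: write each
    # requested submodule in diffusers format, e.g. out_dir/unet/ for
    # `--save-keys unet`.
    for key in save_keys:
        module = getattr(runner.model, key, None)
        if module is not None:
            module.save_pretrained(osp.join(out_dir, key))
```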