From 15dac8533fa69434f5a9b84e11adedbf78864e48 Mon Sep 17 00:00:00 2001
From: okotaku
Date: Fri, 16 Feb 2024 22:46:13 +0000
Subject: [PATCH] Support IP-Adapter SigLIP

---
 .../pokemon_blip_kandinsky_decoder.py         |  3 +-
 .../datasets/pokemon_blip_kandinsky_prior.py  |  3 +-
 .../pokemon_blip_xl_ip_adapter_siglip_384.py  | 60 +++++++++++++++++++
 diffengine/configs/ip_adapter/README.md       |  6 ++
 ..._xl_pokemon_blip_ip_adapter_plus_siglip.py | 22 +++++++
 diffengine/datasets/transforms/processing.py  |  8 ++-
 6 files changed, 97 insertions(+), 5 deletions(-)
 create mode 100644 diffengine/configs/_base_/datasets/pokemon_blip_xl_ip_adapter_siglip_384.py
 create mode 100644 diffengine/configs/ip_adapter/stable_diffusion_xl_pokemon_blip_ip_adapter_plus_siglip.py

diff --git a/diffengine/configs/_base_/datasets/pokemon_blip_kandinsky_decoder.py b/diffengine/configs/_base_/datasets/pokemon_blip_kandinsky_decoder.py
index f95aae1..046159a 100644
--- a/diffengine/configs/_base_/datasets/pokemon_blip_kandinsky_decoder.py
+++ b/diffengine/configs/_base_/datasets/pokemon_blip_kandinsky_decoder.py
@@ -13,7 +13,8 @@
 train_pipeline = [
     dict(type=CLIPImageProcessor,
-         pretrained="kandinsky-community/kandinsky-2-2-prior"),
+         pretrained="kandinsky-community/kandinsky-2-2-prior",
+         subfolder="image_processor"),
     dict(type=TorchVisonTransformWrapper,
          transform=torchvision.transforms.Resize,
          size=768,
          interpolation="bicubic"),
diff --git a/diffengine/configs/_base_/datasets/pokemon_blip_kandinsky_prior.py b/diffengine/configs/_base_/datasets/pokemon_blip_kandinsky_prior.py
index fa3e565..ea9197c 100644
--- a/diffengine/configs/_base_/datasets/pokemon_blip_kandinsky_prior.py
+++ b/diffengine/configs/_base_/datasets/pokemon_blip_kandinsky_prior.py
@@ -6,7 +6,8 @@
 train_pipeline = [
     dict(type=CLIPImageProcessor,
          output_key="img",
-         pretrained="kandinsky-community/kandinsky-2-2-prior"),
+         pretrained="kandinsky-community/kandinsky-2-2-prior",
+         subfolder="image_processor"),
     dict(type=PackInputs),
 ]
 train_dataloader = dict(
diff --git a/diffengine/configs/_base_/datasets/pokemon_blip_xl_ip_adapter_siglip_384.py b/diffengine/configs/_base_/datasets/pokemon_blip_xl_ip_adapter_siglip_384.py
new file mode 100644
index 0000000..cb25966
--- /dev/null
+++ b/diffengine/configs/_base_/datasets/pokemon_blip_xl_ip_adapter_siglip_384.py
@@ -0,0 +1,60 @@
+import torchvision
+from mmengine.dataset import DefaultSampler
+
+from diffengine.datasets import HFDataset
+from diffengine.datasets.transforms import (
+    ComputeTimeIds,
+    PackInputs,
+    RandomCrop,
+    RandomHorizontalFlip,
+    RandomTextDrop,
+    SaveImageShape,
+    TorchVisonTransformWrapper,
+    TransformersImageProcessor,
+)
+from diffengine.engine.hooks import IPAdapterSaveHook, VisualizationHook
+
+train_pipeline = [
+    dict(type=SaveImageShape),
+    dict(type=TransformersImageProcessor,
+         pretrained="google/siglip-so400m-patch14-384"),
+    dict(type=RandomTextDrop),
+    dict(type=TorchVisonTransformWrapper,
+         transform=torchvision.transforms.Resize,
+         size=1024, interpolation="bilinear"),
+    dict(type=RandomCrop, size=1024),
+    dict(type=RandomHorizontalFlip, p=0.5),
+    dict(type=ComputeTimeIds),
+    dict(type=TorchVisonTransformWrapper,
+         transform=torchvision.transforms.ToTensor),
+    dict(type=TorchVisonTransformWrapper,
+         transform=torchvision.transforms.Normalize, mean=[0.5], std=[0.5]),
+    dict(
+        type=PackInputs, input_keys=["img", "text", "time_ids", "clip_img"]),
+]
+train_dataloader = dict(
+    batch_size=2,
+    num_workers=2,
+    dataset=dict(
+        type=HFDataset,
+        dataset="lambdalabs/pokemon-blip-captions",
+        pipeline=train_pipeline),
+    sampler=dict(type=DefaultSampler, shuffle=True),
+)
+
+val_dataloader = None
+val_evaluator = None
+test_dataloader = val_dataloader
+test_evaluator = val_evaluator
+
+custom_hooks = [
+    dict(
+        type=VisualizationHook,
+        prompt=["a drawing of a green pokemon with red eyes"] * 2 + [""] * 2,
+        example_image=[
+            'https://github.com/LambdaLabsML/examples/blob/main/stable-diffusion-finetuning/README_files/README_2_0.png?raw=true'  # noqa
+        ] * 4,
+        height=1024,
+        width=1024),
+    dict(type=IPAdapterSaveHook),
+]
diff --git a/diffengine/configs/ip_adapter/README.md b/diffengine/configs/ip_adapter/README.md
index f773a3d..5dd631b 100644
--- a/diffengine/configs/ip_adapter/README.md
+++ b/diffengine/configs/ip_adapter/README.md
@@ -113,3 +113,9 @@ You can see more details on [`docs/source/run_guides/run_ip_adapter.md`](../../d
 ![input1](https://github.com/LambdaLabsML/examples/blob/main/stable-diffusion-finetuning/README_files/README_2_0.png?raw=true)
 
 ![example1](https://github.com/okotaku/diffengine/assets/24734142/4b37ce6c-60fd-4456-a542-74163927ee01)
+
+#### stable_diffusion_xl_pokemon_blip_ip_adapter_plus_siglip
+
+![input1](https://github.com/LambdaLabsML/examples/blob/main/stable-diffusion-finetuning/README_files/README_2_0.png?raw=true)
+
+![example1](https://github.com/okotaku/diffengine/assets/24734142/61e9279e-bd50-42b7-8a6f-1156a70466ea)
diff --git a/diffengine/configs/ip_adapter/stable_diffusion_xl_pokemon_blip_ip_adapter_plus_siglip.py b/diffengine/configs/ip_adapter/stable_diffusion_xl_pokemon_blip_ip_adapter_plus_siglip.py
new file mode 100644
index 0000000..1df92c0
--- /dev/null
+++ b/diffengine/configs/ip_adapter/stable_diffusion_xl_pokemon_blip_ip_adapter_plus_siglip.py
@@ -0,0 +1,22 @@
+from mmengine.config import read_base
+from transformers import AutoImageProcessor, SiglipVisionModel
+
+with read_base():
+    from .._base_.datasets.pokemon_blip_xl_ip_adapter_siglip_384 import *
+    from .._base_.default_runtime import *
+    from .._base_.models.stable_diffusion_xl_ip_adapter_plus import *
+    from .._base_.schedules.stable_diffusion_xl_50e import *
+
+
+model.image_encoder = dict(
+    type=SiglipVisionModel.from_pretrained,
+    pretrained_model_name_or_path="google/siglip-so400m-patch14-384")
+model.feature_extractor = dict(
+    type=AutoImageProcessor.from_pretrained,
+    pretrained_model_name_or_path="google/siglip-so400m-patch14-384")
+
+train_dataloader.update(batch_size=1)
+
+optim_wrapper.update(accumulative_counts=4)  # optimizer steps once every 4 iterations (effective batch size 4)
+
+train_cfg.update(by_epoch=True, max_epochs=100)
diff --git a/diffengine/datasets/transforms/processing.py b/diffengine/datasets/transforms/processing.py
index d9b5669..cd1c75c 100644
--- a/diffengine/datasets/transforms/processing.py
+++ b/diffengine/datasets/transforms/processing.py
@@ -543,15 +543,17 @@ class CLIPImageProcessor(BaseTransform):
         results. Defaults to 'clip_img'.
     """
 
-    def __init__(self, key: str = "img", output_key: str = "clip_img",
-                 pretrained: str | None = None) -> None:
+    def __init__(self, key: str = "img",
+                 output_key: str = "clip_img",
+                 pretrained: str | None = None,
+                 subfolder: str | None = None) -> None:
         self.key = key
         self.output_key = output_key
         if pretrained is None:
             self.pipeline = HFCLIPImageProcessor()
         else:
             self.pipeline = HFCLIPImageProcessor.from_pretrained(
                 pretrained, subfolder=subfolder)
 
     def transform(self, results: dict) -> dict | tuple[list, list] | None:
         """Transform.
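
For reference, a minimal standalone sketch of the SigLIP encoder/processor pair the new configs load. The checkpoint ID and the example image URL are taken from the configs above; the surrounding I/O (fetching the image over HTTP, printing shapes) is illustrative only and not part of this patch.

```python
# Minimal sketch, not part of this patch: load the same SigLIP vision
# tower and image processor that the new configs point at, and inspect
# the patch embeddings an IP-Adapter Plus resampler would consume.
import requests
from PIL import Image
from transformers import AutoImageProcessor, SiglipVisionModel

REPO = "google/siglip-so400m-patch14-384"  # checkpoint ID from the configs above
processor = AutoImageProcessor.from_pretrained(REPO)
model = SiglipVisionModel.from_pretrained(REPO)

# Example image reused from the VisualizationHook config; any RGB image works.
URL = ("https://github.com/LambdaLabsML/examples/blob/main/"
       "stable-diffusion-finetuning/README_files/README_2_0.png?raw=true")
image = Image.open(requests.get(URL, stream=True).raw).convert("RGB")

inputs = processor(images=image, return_tensors="pt")  # resize/normalize to 384x384
outputs = model(**inputs)
# For this checkpoint the sequence should be 27x27 = 729 patch tokens with
# hidden size 1152, i.e. torch.Size([1, 729, 1152]).
print(outputs.last_hidden_state.shape)
```

The per-patch hidden states, rather than a single pooled vector, are what the "plus" variant of IP-Adapter feeds through its resampler, which is why the config pairs SiglipVisionModel with the stable_diffusion_xl_ip_adapter_plus base model.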