Skip to content

Commit

Permalink
Support IP-Adapter SigLIP
Browse files Browse the repository at this point in the history
  • Loading branch information
okotaku committed Feb 16, 2024
1 parent 4234611 commit 15dac85
Show file tree
Hide file tree
Showing 6 changed files with 97 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@

train_pipeline = [
dict(type=CLIPImageProcessor,
pretrained="kandinsky-community/kandinsky-2-2-prior"),
pretrained="kandinsky-community/kandinsky-2-2-prior",
subfolder="image_processor"),
dict(type=TorchVisonTransformWrapper,
transform=torchvision.transforms.Resize,
size=768, interpolation="bicubic"),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@

train_pipeline = [
dict(type=CLIPImageProcessor, output_key="img",
pretrained="kandinsky-community/kandinsky-2-2-prior"),
pretrained="kandinsky-community/kandinsky-2-2-prior",
subfolder="image_processor"),
dict(type=PackInputs),
]
train_dataloader = dict(
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
import torchvision
from mmengine.dataset import DefaultSampler

from diffengine.datasets import HFDataset
from diffengine.datasets.transforms import (
ComputeTimeIds,
PackInputs,
RandomCrop,
RandomHorizontalFlip,
RandomTextDrop,
SaveImageShape,
TorchVisonTransformWrapper,
TransformersImageProcessor,
)
from diffengine.engine.hooks import IPAdapterSaveHook, VisualizationHook

train_pipeline = [
dict(type=SaveImageShape),
dict(type=TransformersImageProcessor,
pretrained="google/siglip-so400m-patch14-384"),
dict(type=RandomTextDrop),
dict(type=TorchVisonTransformWrapper,
transform=torchvision.transforms.Resize,
size=1024, interpolation="bilinear"),
dict(type=RandomCrop, size=1024),
dict(type=RandomHorizontalFlip, p=0.5),
dict(type=ComputeTimeIds),
dict(type=TorchVisonTransformWrapper,
transform=torchvision.transforms.ToTensor),
dict(type=TorchVisonTransformWrapper,
transform=torchvision.transforms.Normalize, mean=[0.5], std=[0.5]),
dict(
type=PackInputs, input_keys=["img", "text", "time_ids", "clip_img"]),
]
train_dataloader = dict(
batch_size=2,
num_workers=2,
dataset=dict(
type=HFDataset,
dataset="lambdalabs/pokemon-blip-captions",
pipeline=train_pipeline),
sampler=dict(type=DefaultSampler, shuffle=True),
)

val_dataloader = None
val_evaluator = None
test_dataloader = val_dataloader
test_evaluator = val_evaluator

custom_hooks = [
dict(
type=VisualizationHook,
prompt=["a drawing of a green pokemon with red eyes"] * 2 + [""] * 2,
example_image=[
'https://github.com/LambdaLabsML/examples/blob/main/stable-diffusion-finetuning/README_files/README_2_0.png?raw=true' # noqa
] * 4,
height=1024,
width=1024),
dict(type=IPAdapterSaveHook),
]
6 changes: 6 additions & 0 deletions diffengine/configs/ip_adapter/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -113,3 +113,9 @@ You can see more details on [`docs/source/run_guides/run_ip_adapter.md`](../../d
![input1](https://github.com/LambdaLabsML/examples/blob/main/stable-diffusion-finetuning/README_files/README_2_0.png?raw=true)

![example1](https://github.com/okotaku/diffengine/assets/24734142/4b37ce6c-60fd-4456-a542-74163927ee01)

#### stable_diffusion_xl_pokemon_blip_ip_adapter_plus_siglip

![input1](https://github.com/LambdaLabsML/examples/blob/main/stable-diffusion-finetuning/README_files/README_2_0.png?raw=true)

![example1](https://github.com/okotaku/diffengine/assets/24734142/61e9279e-bd50-42b7-8a6f-1156a70466ea)
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
from mmengine.config import read_base
from transformers import AutoImageProcessor, SiglipVisionModel

with read_base():
from .._base_.datasets.pokemon_blip_xl_ip_adapter_siglip_384 import *
from .._base_.default_runtime import *
from .._base_.models.stable_diffusion_xl_ip_adapter_plus import *
from .._base_.schedules.stable_diffusion_xl_50e import *


model.image_encoder = dict(
type=SiglipVisionModel.from_pretrained,
pretrained_model_name_or_path="google/siglip-so400m-patch14-384")
model.feature_extractor = dict(
type=AutoImageProcessor.from_pretrained,
pretrained_model_name_or_path="google/siglip-so400m-patch14-384")

train_dataloader.update(batch_size=1)

optim_wrapper.update(accumulative_counts=4) # update every four times

train_cfg.update(by_epoch=True, max_epochs=100)
8 changes: 5 additions & 3 deletions diffengine/datasets/transforms/processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -543,15 +543,17 @@ class CLIPImageProcessor(BaseTransform):
results. Defaults to 'clip_img'.
"""

def __init__(self, key: str = "img", output_key: str = "clip_img",
pretrained: str | None = None) -> None:
def __init__(self, key: str = "img",
output_key: str = "clip_img",
pretrained: str | None = None,
subfolder: str | None = None) -> None:
self.key = key
self.output_key = output_key
if pretrained is None:
self.pipeline = HFCLIPImageProcessor()
else:
self.pipeline = HFCLIPImageProcessor.from_pretrained(
pretrained, subfolder="image_processor")
pretrained, subfolder=subfolder)

def transform(self, results: dict) -> dict | tuple[list, list] | None:
"""Transform.
Expand Down

0 comments on commit 15dac85

Please sign in to comment.