Skip to content

Commit

Permalink
Merge pull request #102 from okotaku/feat/style_lora
Browse files Browse the repository at this point in the history
[Feature] Support Style LoRA
  • Loading branch information
okotaku authored Nov 29, 2023
2 parents 2888747 + 3307aec commit d9d223c
Show file tree
Hide file tree
Showing 24 changed files with 134 additions and 21 deletions.
38 changes: 38 additions & 0 deletions configs/_base_/datasets/cat_waterpainting_dreambooth_xl.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
train_pipeline = [
dict(type="SaveImageShape"),
dict(type="torchvision/Resize", size=1024, interpolation="bilinear"),
dict(type="RandomCrop", size=1024),
dict(type="RandomHorizontalFlip", p=0.5),
dict(type="ComputeTimeIds"),
dict(type="torchvision/ToTensor"),
dict(type="DumpImage", max_imgs=5, dump_dir="work_dirs/dump"),
dict(type="torchvision/Normalize", mean=[0.5], std=[0.5]),
dict(type="PackInputs", input_keys=["img", "text", "time_ids"]),
]
train_dataloader = dict(
batch_size=1,
num_workers=4,
dataset=dict(
type="HFDreamBoothDataset",
dataset="data/cat_waterpainting",
instance_prompt="A cat in szn style",
pipeline=train_pipeline,
class_prompt=None),
sampler=dict(type="InfiniteSampler", shuffle=True),
)

val_dataloader = None
val_evaluator = None
test_dataloader = val_dataloader
test_evaluator = val_evaluator

custom_hooks = [
dict(
type="VisualizationHook",
prompt=["A man in szn style"] * 4,
by_epoch=False,
interval=100,
height=1024,
width=1024),
dict(type="PeftSaveHook"),
]
37 changes: 37 additions & 0 deletions configs/_base_/datasets/waterpainting_xl.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
train_pipeline = [
dict(type="SaveImageShape"),
dict(type="torchvision/Resize", size=1024, interpolation="bilinear"),
dict(type="RandomCrop", size=1024),
dict(type="RandomHorizontalFlip", p=0.5),
dict(type="ComputeTimeIds"),
dict(type="torchvision/ToTensor"),
dict(type="DumpImage", max_imgs=5, dump_dir="work_dirs/dump"),
dict(type="torchvision/Normalize", mean=[0.5], std=[0.5]),
dict(type="PackInputs", input_keys=["img", "text", "time_ids"]),
]
train_dataloader = dict(
batch_size=2,
num_workers=4,
dataset=dict(
type="HFDataset",
dataset="data/waterpainting",
image_column="file_name",
pipeline=train_pipeline),
sampler=dict(type="InfiniteSampler", shuffle=True),
)

val_dataloader = None
val_evaluator = None
test_dataloader = val_dataloader
test_evaluator = val_evaluator

custom_hooks = [
dict(
type="VisualizationHook",
prompt=["A man in szn style"] * 4,
by_epoch=False,
interval=100,
height=1024,
width=1024),
dict(type="PeftSaveHook"),
]
Original file line number Diff line number Diff line change
Expand Up @@ -9,4 +9,4 @@

train_dataloader = dict(batch_size=1)

optim_wrapper_cfg = dict(accumulative_counts=4) # update every four times
optim_wrapper = dict(accumulative_counts=4) # update every four times
Original file line number Diff line number Diff line change
Expand Up @@ -9,4 +9,4 @@

train_dataloader = dict(batch_size=1)

optim_wrapper_cfg = dict(accumulative_counts=4) # update every four times
optim_wrapper = dict(accumulative_counts=4) # update every four times
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,4 @@

train_dataloader = dict(batch_size=1)

optim_wrapper_cfg = dict(accumulative_counts=4) # update every four times
optim_wrapper = dict(accumulative_counts=4) # update every four times
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,6 @@

train_dataloader = dict(batch_size=1)

optim_wrapper_cfg = dict(accumulative_counts=4) # update every four times
optim_wrapper = dict(accumulative_counts=4) # update every four times

train_cfg = dict(by_epoch=True, max_epochs=100)
2 changes: 1 addition & 1 deletion configs/lcm/lcm_xl_pokemon_blip.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

train_dataloader = dict(batch_size=2)

optim_wrapper_cfg = dict(accumulative_counts=2) # update every four times
optim_wrapper = dict(accumulative_counts=2) # update every four times

custom_hooks = [
dict(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,4 +9,4 @@

train_dataloader = dict(batch_size=1)

optim_wrapper_cfg = dict(accumulative_counts=4) # update every four times
optim_wrapper = dict(accumulative_counts=4) # update every four times
Original file line number Diff line number Diff line change
Expand Up @@ -9,4 +9,4 @@

train_dataloader = dict(batch_size=1)

optim_wrapper_cfg = dict(accumulative_counts=4) # update every four times
optim_wrapper = dict(accumulative_counts=4) # update every four times
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,4 @@

train_dataloader = dict(batch_size=1)

optim_wrapper_cfg = dict(accumulative_counts=4) # update every four times
optim_wrapper = dict(accumulative_counts=4) # update every four times
Original file line number Diff line number Diff line change
Expand Up @@ -9,4 +9,4 @@

train_dataloader = dict(batch_size=1)

optim_wrapper_cfg = dict(accumulative_counts=4) # update every four times
optim_wrapper = dict(accumulative_counts=4) # update every four times
15 changes: 15 additions & 0 deletions configs/stable_diffusion_xl_dreambooth/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,17 @@ Large text-to-image models achieved a remarkable leap in the evolution of AI, en
}
```

## Prepare Dataset

1. Download style data from [StyleDrop-PyTorch](https://github.com/aim-uofa/StyleDrop-PyTorch/tree/main/data).

2. Unzip the files as follows

```
data/cat_waterpainting
└── image_01_03.jpg
```

## Run Training

Run Training
Expand Down Expand Up @@ -102,6 +113,10 @@ Note that we use 2 GPUs for training all lora models.

![exampleback](https://github.com/okotaku/diffengine/assets/24734142/3dd06582-0eb6-4a10-a065-af5e3d452878)

#### stable_diffusion_xl_dreambooth_lora_cat_waterpainting

![examplewaterpaint](https://github.com/okotaku/diffengine/assets/24734142/d2a83569-d2da-4a24-ae2c-e89b1da93fbf)

## Acknowledgement

These experiments are based on [diffusers docs](https://github.com/huggingface/diffusers/blob/main/examples/dreambooth/README_sdxl.md). Thank you for the great experiments.
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
_base_ = [
"../_base_/models/stable_diffusion_xl_lora.py",
"../_base_/datasets/cat_waterpainting_dreambooth_xl.py",
"../_base_/schedules/stable_diffusion_500.py",
"../_base_/default_runtime.py",
]

train_dataloader = dict(
dataset=dict(class_image_config=dict(model={{_base_.model.model}})))

train_cfg = dict(max_iters=1000)
4 changes: 4 additions & 0 deletions configs/stable_diffusion_xl_lora/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -93,3 +93,7 @@ You can see more details on [`docs/source/run_guides/run_lora_xl.md`](../../docs
#### stable_diffusion_xl_lora_conv3x3_pokemon_blip

![example3](https://github.com/okotaku/diffengine/assets/24734142/8ec55900-c08f-4e36-bd18-e6e81d35de6c)

#### stable_diffusion_xl_lora_waterpainting

![example4](https://github.com/okotaku/diffengine/assets/24734142/9bd797e3-07b5-452f-9f68-5bf017896c2f)
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
_base_ = [
"../_base_/models/stable_diffusion_xl_lora.py",
"../_base_/datasets/waterpainting_xl.py",
"../_base_/schedules/stable_diffusion_500.py",
"../_base_/default_runtime.py",
]

train_cfg = dict(max_iters=2000)
Original file line number Diff line number Diff line change
Expand Up @@ -9,4 +9,4 @@

train_dataloader = dict(batch_size=1)

optim_wrapper_cfg = dict(accumulative_counts=4) # update every four times
optim_wrapper = dict(accumulative_counts=4) # update every four times
Original file line number Diff line number Diff line change
Expand Up @@ -9,4 +9,4 @@

train_dataloader = dict(batch_size=1)

optim_wrapper_cfg = dict(accumulative_counts=4) # update every four times
optim_wrapper = dict(accumulative_counts=4) # update every four times
Original file line number Diff line number Diff line change
Expand Up @@ -9,4 +9,4 @@

train_dataloader = dict(batch_size=1)

optim_wrapper_cfg = dict(accumulative_counts=4) # update every four times
optim_wrapper = dict(accumulative_counts=4) # update every four times
2 changes: 1 addition & 1 deletion configs/wuerstchen/wuerstchen_prior_pokemon_blip.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,6 @@
"../_base_/default_runtime.py",
]

optim_wrapper_cfg = dict(
optim_wrapper = dict(
optimizer=dict(lr=1e-5),
accumulative_counts=4) # update every four times
2 changes: 1 addition & 1 deletion docs/source/run_guides/run_ip_adapter.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ _base_ = [
train_dataloader = dict(batch_size=1)
optim_wrapper_cfg = dict(accumulative_counts=4) # update every four times
optim_wrapper = dict(accumulative_counts=4) # update every four times
```

## Run training
Expand Down
2 changes: 1 addition & 1 deletion docs/source/run_guides/run_lcm.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ _base_ = [
train_dataloader = dict(batch_size=2)
optim_wrapper_cfg = dict(accumulative_counts=2) # update every four times
optim_wrapper = dict(accumulative_counts=2) # update every four times
custom_hooks = [
dict(
Expand Down
8 changes: 4 additions & 4 deletions docs/source/run_guides/run_sdxl.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ _base_ = [
train_dataloader = dict(batch_size=1) # Because of GPU memory limit
optim_wrapper_cfg = dict(accumulative_counts=4) # update every four times
optim_wrapper = dict(accumulative_counts=4) # update every four times
```

#### Finetuning the text encoder and UNet
Expand All @@ -35,7 +35,7 @@ _base_ = [
train_dataloader = dict(batch_size=1) # Because of GPU memory limit
optim_wrapper_cfg = dict(accumulative_counts=4) # update every four times
optim_wrapper = dict(accumulative_counts=4) # update every four times
model = dict(finetune_text_encoder=True) # fine tune text encoder
```
Expand All @@ -54,7 +54,7 @@ _base_ = [
train_dataloader = dict(batch_size=1) # Because of GPU memory limit
optim_wrapper_cfg = dict(accumulative_counts=4) # update every four times
optim_wrapper = dict(accumulative_counts=4) # update every four times
custom_hooks = [ # Hook is list, we should write all custom_hooks again.
dict(type='VisualizationHook', prompt=['yoda pokemon'] * 4),
Expand All @@ -77,7 +77,7 @@ _base_ = [
train_dataloader = dict(batch_size=1) # Because of GPU memory limit
optim_wrapper_cfg = dict(accumulative_counts=4) # update every four times
optim_wrapper = dict(accumulative_counts=4) # update every four times
model = dict(loss=dict(type='SNRL2Loss', snr_gamma=5.0, loss_weight=1.0)) # setup Min-SNR Weighting Strategy
```
Expand Down
2 changes: 1 addition & 1 deletion docs/source/run_guides/run_wuerstchen.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ _base_ = [
"../_base_/default_runtime.py",
]
optim_wrapper_cfg = dict(
optim_wrapper = dict(
optimizer=dict(lr=1e-5),
accumulative_counts=4) # update every four times
```
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ dependencies = [
"openmim>=0.3.9",
"datasets>=2.14.6",
"diffusers>=0.23.1",
"mmengine>=0.9.0",
"mmengine>=0.10.1",
"sentencepiece>=0.1.99",
"tqdm",
"transformers>=4.34.1",
Expand Down

0 comments on commit d9d223c

Please sign in to comment.