diff --git a/configs/_base_/datasets/cat_waterpainting_dreambooth_xl.py b/configs/_base_/datasets/cat_waterpainting_dreambooth_xl.py new file mode 100644 index 0000000..acfec80 --- /dev/null +++ b/configs/_base_/datasets/cat_waterpainting_dreambooth_xl.py @@ -0,0 +1,38 @@ +train_pipeline = [ + dict(type="SaveImageShape"), + dict(type="torchvision/Resize", size=1024, interpolation="bilinear"), + dict(type="RandomCrop", size=1024), + dict(type="RandomHorizontalFlip", p=0.5), + dict(type="ComputeTimeIds"), + dict(type="torchvision/ToTensor"), + dict(type="DumpImage", max_imgs=5, dump_dir="work_dirs/dump"), + dict(type="torchvision/Normalize", mean=[0.5], std=[0.5]), + dict(type="PackInputs", input_keys=["img", "text", "time_ids"]), +] +train_dataloader = dict( + batch_size=1, + num_workers=4, + dataset=dict( + type="HFDreamBoothDataset", + dataset="data/cat_waterpainting", + instance_prompt="A cat in szn style", + pipeline=train_pipeline, + class_prompt=None), + sampler=dict(type="InfiniteSampler", shuffle=True), +) + +val_dataloader = None +val_evaluator = None +test_dataloader = val_dataloader +test_evaluator = val_evaluator + +custom_hooks = [ + dict( + type="VisualizationHook", + prompt=["A man in szn style"] * 4, + by_epoch=False, + interval=100, + height=1024, + width=1024), + dict(type="PeftSaveHook"), +] diff --git a/configs/_base_/datasets/waterpainting_xl.py b/configs/_base_/datasets/waterpainting_xl.py new file mode 100644 index 0000000..10240bf --- /dev/null +++ b/configs/_base_/datasets/waterpainting_xl.py @@ -0,0 +1,37 @@ +train_pipeline = [ + dict(type="SaveImageShape"), + dict(type="torchvision/Resize", size=1024, interpolation="bilinear"), + dict(type="RandomCrop", size=1024), + dict(type="RandomHorizontalFlip", p=0.5), + dict(type="ComputeTimeIds"), + dict(type="torchvision/ToTensor"), + dict(type="DumpImage", max_imgs=5, dump_dir="work_dirs/dump"), + dict(type="torchvision/Normalize", mean=[0.5], std=[0.5]), + dict(type="PackInputs", input_keys=["img", "text", "time_ids"]), +] +train_dataloader = dict( + batch_size=2, + num_workers=4, + dataset=dict( + type="HFDataset", + dataset="data/waterpainting", + image_column="file_name", + pipeline=train_pipeline), + sampler=dict(type="InfiniteSampler", shuffle=True), +) + +val_dataloader = None +val_evaluator = None +test_dataloader = val_dataloader +test_evaluator = val_evaluator + +custom_hooks = [ + dict( + type="VisualizationHook", + prompt=["A man in szn style"] * 4, + by_epoch=False, + interval=100, + height=1024, + width=1024), + dict(type="PeftSaveHook"), +] diff --git a/configs/debias_estimation_loss/stable_diffusion_xl_pokemon_blip_debias_estimation_loss.py b/configs/debias_estimation_loss/stable_diffusion_xl_pokemon_blip_debias_estimation_loss.py index 7181258..5ca6b69 100644 --- a/configs/debias_estimation_loss/stable_diffusion_xl_pokemon_blip_debias_estimation_loss.py +++ b/configs/debias_estimation_loss/stable_diffusion_xl_pokemon_blip_debias_estimation_loss.py @@ -9,4 +9,4 @@ train_dataloader = dict(batch_size=1) -optim_wrapper_cfg = dict(accumulative_counts=4) # update every four times +optim_wrapper = dict(accumulative_counts=4) # update every four times diff --git a/configs/input_perturbation/stable_diffusion_xl_pokemon_blip_input_perturbation.py b/configs/input_perturbation/stable_diffusion_xl_pokemon_blip_input_perturbation.py index d023662..6c69543 100644 --- a/configs/input_perturbation/stable_diffusion_xl_pokemon_blip_input_perturbation.py +++ b/configs/input_perturbation/stable_diffusion_xl_pokemon_blip_input_perturbation.py @@ -9,4 +9,4 @@ train_dataloader = dict(batch_size=1) -optim_wrapper_cfg = dict(accumulative_counts=4) # update every four times +optim_wrapper = dict(accumulative_counts=4) # update every four times diff --git a/configs/ip_adapter/stable_diffusion_xl_pokemon_blip_ip_adapter.py b/configs/ip_adapter/stable_diffusion_xl_pokemon_blip_ip_adapter.py index 559891c..cb4e44f 100644 --- a/configs/ip_adapter/stable_diffusion_xl_pokemon_blip_ip_adapter.py +++ b/configs/ip_adapter/stable_diffusion_xl_pokemon_blip_ip_adapter.py @@ -7,4 +7,4 @@ train_dataloader = dict(batch_size=1) -optim_wrapper_cfg = dict(accumulative_counts=4) # update every four times +optim_wrapper = dict(accumulative_counts=4) # update every four times diff --git a/configs/ip_adapter/stable_diffusion_xl_pokemon_blip_ip_adapter_plus.py b/configs/ip_adapter/stable_diffusion_xl_pokemon_blip_ip_adapter_plus.py index 6600c80..4079f3e 100644 --- a/configs/ip_adapter/stable_diffusion_xl_pokemon_blip_ip_adapter_plus.py +++ b/configs/ip_adapter/stable_diffusion_xl_pokemon_blip_ip_adapter_plus.py @@ -7,6 +7,6 @@ train_dataloader = dict(batch_size=1) -optim_wrapper_cfg = dict(accumulative_counts=4) # update every four times +optim_wrapper = dict(accumulative_counts=4) # update every four times train_cfg = dict(by_epoch=True, max_epochs=100) diff --git a/configs/lcm/lcm_xl_pokemon_blip.py b/configs/lcm/lcm_xl_pokemon_blip.py index 6597853..a6a7365 100644 --- a/configs/lcm/lcm_xl_pokemon_blip.py +++ b/configs/lcm/lcm_xl_pokemon_blip.py @@ -7,7 +7,7 @@ train_dataloader = dict(batch_size=2) -optim_wrapper_cfg = dict(accumulative_counts=2) # update every four times +optim_wrapper = dict(accumulative_counts=2) # update every four times custom_hooks = [ dict( diff --git a/configs/offset_noise/stable_diffusion_xl_pokemon_blip_offset_noise.py b/configs/offset_noise/stable_diffusion_xl_pokemon_blip_offset_noise.py index a434002..60c4955 100644 --- a/configs/offset_noise/stable_diffusion_xl_pokemon_blip_offset_noise.py +++ b/configs/offset_noise/stable_diffusion_xl_pokemon_blip_offset_noise.py @@ -9,4 +9,4 @@ train_dataloader = dict(batch_size=1) -optim_wrapper_cfg = dict(accumulative_counts=4) # update every four times +optim_wrapper = dict(accumulative_counts=4) # update every four times diff --git a/configs/pyramid_noise/stable_diffusion_xl_pokemon_blip_pyramid_noise.py b/configs/pyramid_noise/stable_diffusion_xl_pokemon_blip_pyramid_noise.py index 12dc11e..00f1201 100644 --- a/configs/pyramid_noise/stable_diffusion_xl_pokemon_blip_pyramid_noise.py +++ b/configs/pyramid_noise/stable_diffusion_xl_pokemon_blip_pyramid_noise.py @@ -9,4 +9,4 @@ train_dataloader = dict(batch_size=1) -optim_wrapper_cfg = dict(accumulative_counts=4) # update every four times +optim_wrapper = dict(accumulative_counts=4) # update every four times diff --git a/configs/stable_diffusion_xl/stable_diffusion_xl_pokemon_blip.py b/configs/stable_diffusion_xl/stable_diffusion_xl_pokemon_blip.py index a4c65ea..dbce396 100644 --- a/configs/stable_diffusion_xl/stable_diffusion_xl_pokemon_blip.py +++ b/configs/stable_diffusion_xl/stable_diffusion_xl_pokemon_blip.py @@ -7,4 +7,4 @@ train_dataloader = dict(batch_size=1) -optim_wrapper_cfg = dict(accumulative_counts=4) # update every four times +optim_wrapper = dict(accumulative_counts=4) # update every four times diff --git a/configs/stable_diffusion_xl/stable_diffusion_xl_pokemon_blip_pre_compute.py b/configs/stable_diffusion_xl/stable_diffusion_xl_pokemon_blip_pre_compute.py index 457af00..e6c8e84 100644 --- a/configs/stable_diffusion_xl/stable_diffusion_xl_pokemon_blip_pre_compute.py +++ b/configs/stable_diffusion_xl/stable_diffusion_xl_pokemon_blip_pre_compute.py @@ -9,4 +9,4 @@ train_dataloader = dict(batch_size=1) -optim_wrapper_cfg = dict(accumulative_counts=4) # update every four times +optim_wrapper = dict(accumulative_counts=4) # update every four times diff --git a/configs/stable_diffusion_xl_dreambooth/README.md b/configs/stable_diffusion_xl_dreambooth/README.md index 4ef1e94..d7dc823 100644 --- a/configs/stable_diffusion_xl_dreambooth/README.md +++ b/configs/stable_diffusion_xl_dreambooth/README.md @@ -22,6 +22,17 @@ Large text-to-image models achieved a remarkable leap in the evolution of AI, en } ``` +## Prepare Dataset + +1. Download style data from [StyleDrop-PyTorch](https://github.com/aim-uofa/StyleDrop-PyTorch/tree/main/data). + +2. Unzip the files as follows + +``` +data/cat_waterpainting +└── image_01_03.jpg +``` + ## Run Training Run Training @@ -102,6 +113,10 @@ Note that we use 2 GPUs for training all lora models. ![exampleback](https://github.com/okotaku/diffengine/assets/24734142/3dd06582-0eb6-4a10-a065-af5e3d452878) +#### stable_diffusion_xl_dreambooth_lora_cat_waterpainting + +![examplewaterpaint](https://github.com/okotaku/diffengine/assets/24734142/d2a83569-d2da-4a24-ae2c-e89b1da93fbf) + ## Acknowledgement These experiments are based on [diffusers docs](https://github.com/huggingface/diffusers/blob/main/examples/dreambooth/README_sdxl.md). Thank you for the great experiments. diff --git a/configs/stable_diffusion_xl_dreambooth/stable_diffusion_xl_dreambooth_lora_cat_waterpainting.py b/configs/stable_diffusion_xl_dreambooth/stable_diffusion_xl_dreambooth_lora_cat_waterpainting.py new file mode 100644 index 0000000..f7bdd98 --- /dev/null +++ b/configs/stable_diffusion_xl_dreambooth/stable_diffusion_xl_dreambooth_lora_cat_waterpainting.py @@ -0,0 +1,11 @@ +_base_ = [ + "../_base_/models/stable_diffusion_xl_lora.py", + "../_base_/datasets/cat_waterpainting_dreambooth_xl.py", + "../_base_/schedules/stable_diffusion_500.py", + "../_base_/default_runtime.py", +] + +train_dataloader = dict( + dataset=dict(class_image_config=dict(model={{_base_.model.model}}))) + +train_cfg = dict(max_iters=1000) diff --git a/configs/stable_diffusion_xl_lora/README.md b/configs/stable_diffusion_xl_lora/README.md index 75f65d1..d1cf941 100644 --- a/configs/stable_diffusion_xl_lora/README.md +++ b/configs/stable_diffusion_xl_lora/README.md @@ -93,3 +93,7 @@ You can see more details on [`docs/source/run_guides/run_lora_xl.md`](../../docs #### stable_diffusion_xl_lora_conv3x3_pokemon_blip ![example3](https://github.com/okotaku/diffengine/assets/24734142/8ec55900-c08f-4e36-bd18-e6e81d35de6c) + +#### stable_diffusion_xl_lora_waterpainting + +![example4](https://github.com/okotaku/diffengine/assets/24734142/9bd797e3-07b5-452f-9f68-5bf017896c2f) diff --git a/configs/stable_diffusion_xl_lora/stable_diffusion_xl_lora_waterpainting.py b/configs/stable_diffusion_xl_lora/stable_diffusion_xl_lora_waterpainting.py new file mode 100644 index 0000000..a790b81 --- /dev/null +++ b/configs/stable_diffusion_xl_lora/stable_diffusion_xl_lora_waterpainting.py @@ -0,0 +1,8 @@ +_base_ = [ + "../_base_/models/stable_diffusion_xl_lora.py", + "../_base_/datasets/waterpainting_xl.py", + "../_base_/schedules/stable_diffusion_500.py", + "../_base_/default_runtime.py", +] + +train_cfg = dict(max_iters=2000) diff --git a/configs/timesteps_bias/stable_diffusion_xl_pokemon_blip_earlier_bias.py b/configs/timesteps_bias/stable_diffusion_xl_pokemon_blip_earlier_bias.py index 7e96af1..45c8751 100644 --- a/configs/timesteps_bias/stable_diffusion_xl_pokemon_blip_earlier_bias.py +++ b/configs/timesteps_bias/stable_diffusion_xl_pokemon_blip_earlier_bias.py @@ -9,4 +9,4 @@ train_dataloader = dict(batch_size=1) -optim_wrapper_cfg = dict(accumulative_counts=4) # update every four times +optim_wrapper = dict(accumulative_counts=4) # update every four times diff --git a/configs/timesteps_bias/stable_diffusion_xl_pokemon_blip_later_bias.py b/configs/timesteps_bias/stable_diffusion_xl_pokemon_blip_later_bias.py index c871175..09a2caf 100644 --- a/configs/timesteps_bias/stable_diffusion_xl_pokemon_blip_later_bias.py +++ b/configs/timesteps_bias/stable_diffusion_xl_pokemon_blip_later_bias.py @@ -9,4 +9,4 @@ train_dataloader = dict(batch_size=1) -optim_wrapper_cfg = dict(accumulative_counts=4) # update every four times +optim_wrapper = dict(accumulative_counts=4) # update every four times diff --git a/configs/timesteps_bias/stable_diffusion_xl_pokemon_blip_range_bias.py b/configs/timesteps_bias/stable_diffusion_xl_pokemon_blip_range_bias.py index a13c953..9d0c862 100644 --- a/configs/timesteps_bias/stable_diffusion_xl_pokemon_blip_range_bias.py +++ b/configs/timesteps_bias/stable_diffusion_xl_pokemon_blip_range_bias.py @@ -9,4 +9,4 @@ train_dataloader = dict(batch_size=1) -optim_wrapper_cfg = dict(accumulative_counts=4) # update every four times +optim_wrapper = dict(accumulative_counts=4) # update every four times diff --git a/configs/wuerstchen/wuerstchen_prior_pokemon_blip.py b/configs/wuerstchen/wuerstchen_prior_pokemon_blip.py index e8396eb..891fde6 100644 --- a/configs/wuerstchen/wuerstchen_prior_pokemon_blip.py +++ b/configs/wuerstchen/wuerstchen_prior_pokemon_blip.py @@ -5,6 +5,6 @@ "../_base_/default_runtime.py", ] -optim_wrapper_cfg = dict( +optim_wrapper = dict( optimizer=dict(lr=1e-5), accumulative_counts=4) # update every four times diff --git a/docs/source/run_guides/run_ip_adapter.md b/docs/source/run_guides/run_ip_adapter.md index 33c235c..cbccebf 100644 --- a/docs/source/run_guides/run_ip_adapter.md +++ b/docs/source/run_guides/run_ip_adapter.md @@ -18,7 +18,7 @@ _base_ = [ train_dataloader = dict(batch_size=1) -optim_wrapper_cfg = dict(accumulative_counts=4) # update every four times +optim_wrapper = dict(accumulative_counts=4) # update every four times ``` ## Run training diff --git a/docs/source/run_guides/run_lcm.md b/docs/source/run_guides/run_lcm.md index f0cee8e..61863d1 100644 --- a/docs/source/run_guides/run_lcm.md +++ b/docs/source/run_guides/run_lcm.md @@ -18,7 +18,7 @@ _base_ = [ train_dataloader = dict(batch_size=2) -optim_wrapper_cfg = dict(accumulative_counts=2) # update every four times +optim_wrapper = dict(accumulative_counts=2) # update every four times custom_hooks = [ dict( diff --git a/docs/source/run_guides/run_sdxl.md b/docs/source/run_guides/run_sdxl.md index 3f14227..d611b78 100644 --- a/docs/source/run_guides/run_sdxl.md +++ b/docs/source/run_guides/run_sdxl.md @@ -18,7 +18,7 @@ _base_ = [ train_dataloader = dict(batch_size=1) # Because of GPU memory limit -optim_wrapper_cfg = dict(accumulative_counts=4) # update every four times +optim_wrapper = dict(accumulative_counts=4) # update every four times ``` #### Finetuning the text encoder and UNet @@ -35,7 +35,7 @@ _base_ = [ train_dataloader = dict(batch_size=1) # Because of GPU memory limit -optim_wrapper_cfg = dict(accumulative_counts=4) # update every four times +optim_wrapper = dict(accumulative_counts=4) # update every four times model = dict(finetune_text_encoder=True) # fine tune text encoder ``` @@ -54,7 +54,7 @@ _base_ = [ train_dataloader = dict(batch_size=1) # Because of GPU memory limit -optim_wrapper_cfg = dict(accumulative_counts=4) # update every four times +optim_wrapper = dict(accumulative_counts=4) # update every four times custom_hooks = [ # Hook is list, we should write all custom_hooks again. dict(type='VisualizationHook', prompt=['yoda pokemon'] * 4), @@ -77,7 +77,7 @@ _base_ = [ train_dataloader = dict(batch_size=1) # Because of GPU memory limit -optim_wrapper_cfg = dict(accumulative_counts=4) # update every four times +optim_wrapper = dict(accumulative_counts=4) # update every four times model = dict(loss=dict(type='SNRL2Loss', snr_gamma=5.0, loss_weight=1.0)) # setup Min-SNR Weighting Strategy ``` diff --git a/docs/source/run_guides/run_wuerstchen.md b/docs/source/run_guides/run_wuerstchen.md index c13614a..1385d8c 100644 --- a/docs/source/run_guides/run_wuerstchen.md +++ b/docs/source/run_guides/run_wuerstchen.md @@ -16,7 +16,7 @@ _base_ = [ "../_base_/default_runtime.py", ] -optim_wrapper_cfg = dict( +optim_wrapper = dict( optimizer=dict(lr=1e-5), accumulative_counts=4) # update every four times ``` diff --git a/pyproject.toml b/pyproject.toml index 2f6e4f7..7a44f0d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -11,7 +11,7 @@ dependencies = [ "openmim>=0.3.9", "datasets>=2.14.6", "diffusers>=0.23.1", - "mmengine>=0.9.0", + "mmengine>=0.10.1", "sentencepiece>=0.1.99", "tqdm", "transformers>=4.34.1",