diff --git a/Dockerfile b/Dockerfile index b782e90..fc9e771 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,4 +1,4 @@ -FROM nvcr.io/nvidia/pytorch:23.10-py3 +FROM nvcr.io/nvidia/pytorch:23.11-py3 RUN apt update -y && apt install -y \ git tmux diff --git a/configs/stable_diffusion_xl/README.md b/configs/stable_diffusion_xl/README.md index ee0e88a..4b6667c 100644 --- a/configs/stable_diffusion_xl/README.md +++ b/configs/stable_diffusion_xl/README.md @@ -31,6 +31,8 @@ $ mim train diffengine configs/stable_diffusion_xl/stable_diffusion_xl_pokemon_b ## Training Speed +#### Single GPU + Environment: - A6000 Single GPU @@ -48,6 +50,24 @@ Settings: Note that `stable_diffusion_xl_pokemon_blip_fast` took a few minutes to compile. We will disregard it. +#### Multiple GPUs + +Environment: + +- A100 x 4 GPUs +- nvcr.io/nvidia/pytorch:23.11-py3 + +Settings: + +- 1epoch training. + +| Model | total time | +| :-------------------------------------------------------: | :--------: | +| stable_diffusion_xl_pokemon_blip_fast (BS=4) | 1 m 6 s | +| stable_diffusion_xl_pokemon_blip_deepspeed_stage3 (BS=8) | 1 m 5 s | +| stable_diffusion_xl_pokemon_blip_deepspeed_stage2 (BS=8) | 58 s | +| stable_diffusion_xl_pokemon_blip_colossal (stage=2, BS=8) | 58s | + ## Inference with diffusers Once you have trained a model, specify the path to the saved model and utilize it for inference using the `diffusers.pipeline` module. diff --git a/configs/stable_diffusion_xl/stable_diffusion_xl_pokemon_blip_colossal.py b/configs/stable_diffusion_xl/stable_diffusion_xl_pokemon_blip_colossal.py new file mode 100644 index 0000000..ddde9e4 --- /dev/null +++ b/configs/stable_diffusion_xl/stable_diffusion_xl_pokemon_blip_colossal.py @@ -0,0 +1,44 @@ +_base_ = [ + "../_base_/models/stable_diffusion_xl.py", + "../_base_/datasets/pokemon_blip_xl.py", + "../_base_/schedules/stable_diffusion_xl_50e.py", + "../_base_/default_runtime.py", +] + +model = dict( + gradient_checkpointing=False) + +train_dataloader = dict(batch_size=8, num_workers=8) + +optim_wrapper = dict( + _delete_=True, + optimizer=dict( + type="HybridAdam", + lr=1e-5, + weight_decay=1e-2), + accumulative_counts=4) + +env_cfg = dict( + cudnn_benchmark=True, +) + +custom_hooks = [ + dict( + type="VisualizationHook", + prompt=["yoda pokemon"] * 4, + height=1024, + width=1024), + dict(type="SDCheckpointHook"), + dict(type="FastNormHook", fuse_main_ln=False, fuse_gn=False), + dict(type="CompileHook", compile_main=True), +] + +default_hooks = dict( + checkpoint=dict(save_param_scheduler=False)) # no scheduler in this config + +runner_type = "FlexibleRunner" +strategy = dict(type="ColossalAIStrategy", + mixed_precision="fp16", + plugin=dict(type="LowLevelZeroPlugin", + stage=2, + max_norm=1.0)) diff --git a/configs/stable_diffusion_xl/stable_diffusion_xl_pokemon_blip_deepspeed_stage2.py b/configs/stable_diffusion_xl/stable_diffusion_xl_pokemon_blip_deepspeed_stage2.py new file mode 100644 index 0000000..7fa6705 --- /dev/null +++ b/configs/stable_diffusion_xl/stable_diffusion_xl_pokemon_blip_deepspeed_stage2.py @@ -0,0 +1,60 @@ +_base_ = [ + "../_base_/models/stable_diffusion_xl.py", + "../_base_/datasets/pokemon_blip_xl.py", + "../_base_/schedules/stable_diffusion_xl_50e.py", + "../_base_/default_runtime.py", +] + +model = dict( + enable_xformers=True, + gradient_checkpointing=False) + +train_dataloader = dict(batch_size=8, num_workers=8) + +optim_wrapper = dict( + _delete_=True, + type="DeepSpeedOptimWrapper", + optimizer=dict( + type="FusedAdam", + lr=1e-5, + weight_decay=1e-2)) + +env_cfg = dict( + cudnn_benchmark=True, +) + +custom_hooks = [ + dict( + type="VisualizationHook", + prompt=["yoda pokemon"] * 4, + height=1024, + width=1024), + dict(type="SDCheckpointHook"), + dict(type="FastNormHook", fuse_main_ln=False, fuse_gn=False), +] + +runner_type = "FlexibleRunner" +strategy = dict( + type="DeepSpeedStrategy", + gradient_clipping=1.0, + gradient_accumulation_steps=4, + fp16=dict( + enabled=True, + fp16_master_weights_and_grads=False, + loss_scale=0, + loss_scale_window=500, + hysteresis=2, + min_loss_scale=1, + initial_scale_power=15, + ), + inputs_to_half=["inputs"], + zero_optimization=dict( + stage=2, + allgather_partitions=True, + reduce_scatter=True, + allgather_bucket_size=50000000, + reduce_bucket_size=50000000, + overlap_comm=False, + contiguous_gradients=True, + cpu_offload=False), +) diff --git a/configs/stable_diffusion_xl/stable_diffusion_xl_pokemon_blip_deepspeed_stage3.py b/configs/stable_diffusion_xl/stable_diffusion_xl_pokemon_blip_deepspeed_stage3.py new file mode 100644 index 0000000..4bbd688 --- /dev/null +++ b/configs/stable_diffusion_xl/stable_diffusion_xl_pokemon_blip_deepspeed_stage3.py @@ -0,0 +1,60 @@ +_base_ = [ + "../_base_/models/stable_diffusion_xl.py", + "../_base_/datasets/pokemon_blip_xl.py", + "../_base_/schedules/stable_diffusion_xl_50e.py", + "../_base_/default_runtime.py", +] + +model = dict( + enable_xformers=True, + gradient_checkpointing=False) + +train_dataloader = dict(batch_size=8, num_workers=8) + +optim_wrapper = dict( + _delete_=True, + type="DeepSpeedOptimWrapper", + optimizer=dict( + type="FusedAdam", + lr=1e-5, + weight_decay=1e-2)) + +env_cfg = dict( + cudnn_benchmark=True, +) + +custom_hooks = [ + dict( + type="VisualizationHook", + prompt=["yoda pokemon"] * 4, + height=1024, + width=1024), + dict(type="SDCheckpointHook"), + dict(type="FastNormHook", fuse_main_ln=False, fuse_gn=False), +] + +runner_type = "FlexibleRunner" +strategy = dict( + type="DeepSpeedStrategy", + gradient_clipping=1.0, + gradient_accumulation_steps=4, + fp16=dict( + enabled=True, + fp16_master_weights_and_grads=False, + loss_scale=0, + loss_scale_window=500, + hysteresis=2, + min_loss_scale=1, + initial_scale_power=15, + ), + inputs_to_half=["inputs"], + zero_optimization=dict( + stage=3, + allgather_partitions=True, + reduce_scatter=True, + allgather_bucket_size=50000000, + reduce_bucket_size=50000000, + overlap_comm=True, + contiguous_gradients=True, + cpu_offload=False), +) diff --git a/pyproject.toml b/pyproject.toml index 8b8fc83..ca7ea50 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -23,7 +23,7 @@ keywords = ["computer vision", "diffusion models"] [project.optional-dependencies] dev = ["pytest", "coverage"] -optional = ["ftfy", "bs4"] +optional = ["ftfy", "bs4", "deepspeed", "colossalai"] docs = [ "docutils==0.18.1", "modelindex",