Skip to content

Commit

Permalink
Support csv for dreambooth dataset
Browse files Browse the repository at this point in the history
  • Loading branch information
okotaku committed Dec 13, 2023
1 parent 96082d9 commit 5d37fcc
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 1 deletion.
17 changes: 16 additions & 1 deletion diffengine/datasets/hf_dreambooth_datasets.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# flake8: noqa: S311,RUF012
import copy
import hashlib
import os
import random
import shutil
from collections.abc import Sequence
Expand Down Expand Up @@ -45,6 +46,9 @@ class HFDreamBoothDataset(Dataset):
class_prompt (Optional[str]): The prompt to specify images in the same
class as provided instance images. Defaults to None.
pipeline (Sequence): Processing pipeline. Defaults to an empty tuple.
csv (str, optional): Image path csv file name when loading local
folder. If None, the dataset will be loaded from image folders.
Defaults to None.
cache_dir (str, optional): The directory where the downloaded datasets
will be stored.Defaults to None.
"""
Expand All @@ -65,8 +69,12 @@ def __init__(self,
class_image_config: dict | None = None,
class_prompt: str | None = None,
pipeline: Sequence = (),
csv: str | None = None,
cache_dir: str | None = None) -> None:

self.dataset_name = dataset
self.csv = csv

if class_image_config is None:
class_image_config = {
"model": "runwayml/stable-diffusion-v1-5",
Expand All @@ -77,7 +85,12 @@ def __init__(self,
}
if Path(dataset).exists():
# load local folder
self.dataset = load_dataset(dataset, cache_dir=cache_dir)["train"]
if csv is not None:
data_file = os.path.join(dataset, csv)
self.dataset = load_dataset(
"csv", data_files=data_file, cache_dir=cache_dir)["train"]
else:
self.dataset = load_dataset(dataset, cache_dir=cache_dir)["train"]
else: # noqa
# load huggingface online
if dataset_sub_dir is not None:
Expand Down Expand Up @@ -172,6 +185,8 @@ def __getitem__(self, idx: int) -> dict:
data_info = self.dataset[idx]
image = data_info[self.image_column]
if isinstance(image, str):
if self.csv is not None:
image = os.path.join(self.dataset_name, image)
image = Image.open(image)
image = image.convert("RGB")
result = {"img": image, "text": self.instance_prompt}
Expand Down
2 changes: 2 additions & 0 deletions diffengine/engine/hooks/peft_save_hook.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@ def before_save_checkpoint(self, runner, checkpoint: dict) -> None:
model.unet.save_pretrained(osp.join(ckpt_path, "unet"))
model_keys = ["unet"]
elif hasattr(model, "prior"):
# todo[okotaku]: Delete if bug is fixed in diffusers. # noqa
model.prior._internal_dict["_name_or_path"] = "prior" # noqa
model.prior.save_pretrained(osp.join(ckpt_path, "prior"))
model_keys = ["prior"]
elif hasattr(model, "transformer"):
Expand Down
13 changes: 13 additions & 0 deletions tests/test_datasets/test_hf_dreambooth_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,3 +69,16 @@ def test_dataset_from_local(self):
assert data["text"] == "a photo of sks dog"
assert isinstance(data["img"], Image.Image)
assert data["img"].width == 400

def test_dataset_from_local_with_csv(self):
dataset = HFDreamBoothDataset(
dataset="tests/testdata/dataset",
csv="metadata.csv",
image_column="file_name",
instance_prompt="a photo of sks dog")
assert len(dataset) == 1

data = dataset[0]
assert data["text"] == "a photo of sks dog"
assert isinstance(data["img"], Image.Image)
assert data["img"].width == 400

0 comments on commit 5d37fcc

Please sign in to comment.