From 5da7ed41a7c32ac4bf5f77bcb62d91d22c514cd7 Mon Sep 17 00:00:00 2001 From: Kai Norman Clasen Date: Fri, 21 Jun 2024 16:40:24 +0200 Subject: [PATCH] add threading for testing --- run.py | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/run.py b/run.py index f2d915a..1d7c783 100644 --- a/run.py +++ b/run.py @@ -12,6 +12,7 @@ from natsort import natsorted from more_itertools import chunked from tqdm import tqdm +from concurrent import futures log = structlog.get_logger() @@ -97,9 +98,17 @@ def main( for keys_chunk in tqdm(chunked(lmdb_keys, 512)): with env.begin(write=True) as txn: log.debug(f"First key of the chunk is: {keys_chunk[0]}") - for key in keys_chunk: - if not txn.put(str(key).encode(), s2_safetensor_generator(key, grouped[key]), overwrite=False): - sys.exit("Program is overwriting data in the DB! This should never happen!") + with futures.ThreadPoolExecutor(max_workers=64) as executor: + future_to_key = {executor.submit(writer, txn, key, grouped[key]): key for key in keys_chunk} + for future in futures.as_completed(future_to_key): + if not future.result(): + sys.exit(f"Program is overwriting data {future_to_key[future]} in the DB! This should never happen!") + + # for key in keys_chunk: + # if not txn.put(str(key).encode(), s2_safetensor_generator(key, grouped[key]), overwrite=False): + +def writer(txn, key: Path, files): + return txn.put(str(key).encode(), s2_safetensor_generator(key, files)) if __name__ == "__main__": typer.run(main)