Skip to content

Commit

Permalink
add threading for testing
Browse files Browse the repository at this point in the history
  • Loading branch information
kai-tub committed Jun 21, 2024
1 parent 0dca694 commit 5da7ed4
Showing 1 changed file with 12 additions and 3 deletions.
15 changes: 12 additions & 3 deletions run.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from natsort import natsorted
from more_itertools import chunked
from tqdm import tqdm
from concurrent import futures

log = structlog.get_logger()

Expand Down Expand Up @@ -97,9 +98,17 @@ def main(
for keys_chunk in tqdm(chunked(lmdb_keys, 512)):
with env.begin(write=True) as txn:
log.debug(f"First key of the chunk is: {keys_chunk[0]}")
for key in keys_chunk:
if not txn.put(str(key).encode(), s2_safetensor_generator(key, grouped[key]), overwrite=False):
sys.exit("Program is overwriting data in the DB! This should never happen!")
with futures.ThreadPoolExecutor(max_workers=64) as executor:
future_to_key = {executor.submit(writer, txn, key, grouped[key]): key for key in keys_chunk}
for future in futures.as_completed(future_to_key):
if not future.result():
sys.exit(f"Program is overwriting data {future_to_key[future]} in the DB! This should never happen!")

# for key in keys_chunk:
# if not txn.put(str(key).encode(), s2_safetensor_generator(key, grouped[key]), overwrite=False):

def writer(txn, key: Path, files):
    """Serialize the files of one key and insert the result into the DB.

    The value is produced by ``s2_safetensor_generator(key, files)`` and
    stored under ``str(key).encode()`` via ``txn.put``.

    Args:
        txn: Open write transaction exposing ``put(key, value, overwrite=...)``.
        key: Identifier of the entry; its string form is used as the DB key.
        files: Files belonging to ``key``; passed through to the generator.

    Returns:
        The result of ``txn.put``: truthy if the entry was written, falsy if
        the key already existed (nothing is overwritten in that case). The
        caller treats a falsy result as a fatal "overwriting data" error.

    NOTE(review): the caller submits this function to a thread pool while
    sharing one write transaction; py-lmdb documents write transactions as
    usable only from the thread that created them -- confirm this is safe.
    """
    payload = s2_safetensor_generator(key, files)
    # overwrite=False is required for the duplicate-key guard in the caller:
    # with the default (overwrite=True) put() replaces existing entries and
    # reports success, so an accidental overwrite would go unnoticed.
    return txn.put(str(key).encode(), payload, overwrite=False)

# Script entry point: only when executed directly (not on import), hand
# `main` to typer to run.
if __name__ == "__main__":
    typer.run(main)

0 comments on commit 5da7ed4

Please sign in to comment.