Skip to content

Commit

Permalink
Initial support for hyspecnet
Browse files Browse the repository at this point in the history
  • Loading branch information
kai-tub committed Jun 10, 2024
1 parent 9faa6d0 commit dfbd8e1
Show file tree
Hide file tree
Showing 3 changed files with 359 additions and 46 deletions.
268 changes: 265 additions & 3 deletions integration_tests/python_tests.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -86,11 +86,273 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 12,
"id": "1618f995-d4f5-4a29-b868-d7958ef04c8a",
"metadata": {},
"outputs": [],
"source": []
"source": [
"import lmdb\n",
"import rasterio\n",
"import safetensors\n",
"import numpy as np\n",
"from pathlib import Path\n",
"from safetensors.numpy import deserialize, load_file, load\n",
"\n",
"# def read_single_band_raster(path):\n",
"# with rasterio.open(path) as r:\n",
"# return r.read(1)\n",
"\n",
"# p = Path(\"tiffs/BigEarthNet\")\n",
"# source_data = {file: read_single_band_raster(file) for file in p.glob(\"**/*.tif*\")}\n",
"\n",
"# code to create the directory\n",
"# ./result/bin/encoder --bigearthnet-s1-root tiffs/BigEarthNet/S1/ --bigearthnet-s2-root tiffs/BigEarthNet/S2/ artifacts/\n",
"env = lmdb.open(\"../artifacts\", readonly=True)\n",
"\n",
"with env.begin(write=False) as txn:\n",
" cur = txn.cursor()\n",
" decoded_lmdb_data = {k.decode(\"utf-8\"): load(v) for (k, v) in cur}\n",
"\n",
"# The encoded data is nested inside of another safetensor dictionary, where the inner keys are derived from the band suffix\n",
"# decoded_values = [v for outer_v in decoded_lmdb_data.values() for v in outer_v.values()]\n",
"\n",
"# # Simply check if the data remains identical, as this is the only _true_ thing I care about from the Python viewpoint\n",
"# # If the keys/order or anything else is wrong, it isn't part of the integration test but should be handled separately as a unit test!\n",
"# for (source_key, source_value) in source_data.items():\n",
"# assert any(np.array_equal(source_value, decoded_value) for decoded_value in decoded_values), f\"Couldn't find data in the LMDB database that matches the data from: {source_key}\""
]
},
{
"cell_type": "code",
"execution_count": 30,
"id": "3d819f7d-4c0c-41f2-b874-b2a6c8d3d0a5",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"50.2 µs ± 651 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)\n"
]
}
],
"source": [
"%%timeit\n",
"# Time 224 individual lookups (inner keys '1'..'224') on a single decoded LMDB entry.\n",
"for band in range(1, 225):\n",
"    decoded_lmdb_data['ENMAP01-____L2A-DT0000004950_20221103T162438Z_001_V010110_20221118T145147Z-Y01460273_X03110438'][str(band)]"
]
},
{
"cell_type": "code",
"execution_count": 31,
"id": "14bce7df-6aae-4c2d-92ca-f2f921a93268",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"4.74 µs ± 29 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)\n"
]
}
],
"source": [
"%%timeit\n",
"# Same lookup benchmark restricted to the first 21 inner keys ('1'..'21').\n",
"for band in range(1, 22):\n",
"    decoded_lmdb_data['ENMAP01-____L2A-DT0000004950_20221103T162438Z_001_V010110_20221118T145147Z-Y01460273_X03110438'][str(band)]"
]
},
{
"cell_type": "code",
"execution_count": 23,
"id": "26d974c0-531e-419e-9f69-656db858292f",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"995 µs ± 9.96 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)\n"
]
}
],
"source": [
"%%timeit\n",
"# Time collecting the arrays stored under '1'..'224' and stacking them along a new leading axis.\n",
"band_arrays = [decoded_lmdb_data['ENMAP01-____L2A-DT0000004950_20221103T162438Z_001_V010110_20221118T145147Z-Y01460273_X03110438'][str(band)] for band in range(1, 225)]\n",
"np.stack(band_arrays, axis=0)"
]
},
{
"cell_type": "code",
"execution_count": 24,
"id": "3383c0a5-2130-4412-9ff7-94666b7a5f72",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"71.4 µs ± 107 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)\n"
]
}
],
"source": [
"%%timeit\n",
"# Same stacking benchmark restricted to the arrays stored under '1'..'21'.\n",
"band_arrays = [decoded_lmdb_data['ENMAP01-____L2A-DT0000004950_20221103T162438Z_001_V010110_20221118T145147Z-Y01460273_X03110438'][str(band)] for band in range(1, 22)]\n",
"np.stack(band_arrays, axis=0)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "2a427582-9cc8-4493-b909-5791160da3e8",
"metadata": {},
"outputs": [],
"source": [
"%%timeit\n",
"# NOTE(review): unfinished scratch cell - this only times binding the np.zeros function\n",
"# object to a local name; no array is allocated here.\n",
"zeros_fn = np.zeros"
]
},
{
"cell_type": "code",
"execution_count": 34,
"id": "7759f06d-75ef-40ad-a6b9-afed13d5a352",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(4, 128, 128)"
]
},
"execution_count": 34,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"SUPPORTED_BANDS = list(range(5))\n",
"# Stack only those of the inner keys '1'..'21' whose number is listed in SUPPORTED_BANDS and show the shape.\n",
"# The iteration starts at 1, so the 0 in SUPPORTED_BANDS never matches.\n",
"np.stack([decoded_lmdb_data['ENMAP01-____L2A-DT0000004950_20221103T162438Z_001_V010110_20221118T145147Z-Y01460273_X03110438'][str(band)] for band in range(1, 22) if band in SUPPORTED_BANDS], axis=0).shape"
]
},
{
"cell_type": "code",
"execution_count": 36,
"id": "dec80ead-1c04-42d7-99de-92df4c3a7a76",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"5.96 ms ± 185 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n"
]
}
],
"source": [
"%%timeit\n",
"# Alternative to np.stack: pre-allocate the output array and copy one band per row.\n",
"num_bands = 224\n",
"a = np.zeros((num_bands, 128, 128))\n",
"# Inner keys are 1-based ('1'..'224'), so the upper bound must be num_bands + 1;\n",
"# the previous range(1, 224) left the last row of `a` unfilled.\n",
"for i in range(1, num_bands + 1):\n",
"    a[i-1] = decoded_lmdb_data['ENMAP01-____L2A-DT0000004950_20221103T162438Z_001_V010110_20221118T145147Z-Y01460273_X03110438'][str(i)]"
]
},
{
"cell_type": "code",
"execution_count": 54,
"id": "3341546c-0fe8-4f52-8b66-6190db7002e8",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"2.61 ms ± 76.2 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n"
]
}
],
"source": [
"# Bind the bounds and `clipped` at notebook scope: names assigned inside a %%timeit body\n",
"# stay local to the timing run, so the following cells could not see them after a\n",
"# fresh Restart & Run All.\n",
"minimum_value = 0\n",
"maximum_value = 10000\n",
"\n",
"\n",
"def stack_and_clip():\n",
"    \"\"\"Stack the arrays stored under '1'..'224' and clip them to [minimum_value, maximum_value].\"\"\"\n",
"    return np.stack([decoded_lmdb_data['ENMAP01-____L2A-DT0000004950_20221103T162438Z_001_V010110_20221118T145147Z-Y01460273_X03110438'][str(i)] for i in range(1, 225)], axis=0).clip(min=minimum_value, max=maximum_value)\n",
"\n",
"\n",
"clipped = stack_and_clip()\n",
"%timeit stack_and_clip()"
]
},
{
"cell_type": "code",
"execution_count": 57,
"id": "a5a5702d-377b-4218-b40b-99339ecc6946",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"7.15 ms ± 436 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n"
]
}
],
"source": [
"# Bind `out_data` at notebook scope first: a name assigned inside a %%timeit body is\n",
"# local to the timing run and would be missing in the next cell after a kernel restart.\n",
"out_data = (clipped - minimum_value) / (maximum_value - minimum_value)\n",
"# Time the same min-max scaling expression on its own.\n",
"%timeit (clipped - minimum_value) / (maximum_value - minimum_value)"
]
},
{
"cell_type": "code",
"execution_count": 59,
"id": "b256f6b1-10c0-4219-a00e-6a45ec74499b",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"3.21 ms ± 111 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n"
]
}
],
"source": [
"%%timeit\n",
"# Time only the cast of the already scaled `out_data` to float32.\n",
"out_data_f32 = out_data.astype(np.float32)"
]
},
"cell_type": "code",
"execution_count": 60,
"id": "ae7be06a-3d12-44af-aa82-c233d65824ec",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"10.1 ms ± 403 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n"
]
}
],
"source": [
"%%timeit\n",
"# astype without explicit intermediate value is just as fast as with intermediate value\n",
"out_data = ((clipped - minimum_value) / (maximum_value - minimum_value)).astype(np.float32)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "37c8e0f7-03d5-4ed4-bd44-99d5415447ad",
"metadata": {},
"outputs": [],
"source": [
"# For a single patch, this takes around 10 ms."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "49f82b44-6cad-44b4-963a-9d39897c2f37",
"metadata": {},
"outputs": [],
"source": [
"# 0.72 batches/s in Martin's run, for the entire training"
]
}
],
"metadata": {
Expand All @@ -109,7 +371,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.8"
"version": "3.11.9"
}
},
"nbformat": 4,
Expand Down
Binary file not shown.
Loading

0 comments on commit dfbd8e1

Please sign in to comment.