diff --git a/integration_tests/python_tests.ipynb b/integration_tests/python_tests.ipynb
index 5bd21a6..1002b01 100644
--- a/integration_tests/python_tests.ipynb
+++ b/integration_tests/python_tests.ipynb
@@ -86,11 +86,273 @@
   },
   {
    "cell_type": "code",
-   "execution_count": null,
+   "execution_count": 12,
    "id": "1618f995-d4f5-4a29-b868-d7958ef04c8a",
    "metadata": {},
    "outputs": [],
-   "source": []
+   "source": [
+    "import lmdb\n",
+    "import rasterio\n",
+    "import safetensors\n",
+    "import numpy as np\n",
+    "from pathlib import Path\n",
+    "from safetensors.numpy import deserialize, load_file, load\n",
+    "\n",
+    "# def read_single_band_raster(path):\n",
+    "#     with rasterio.open(path) as r:\n",
+    "#         return r.read(1)\n",
+    "\n",
+    "# p = Path(\"tiffs/BigEarthNet\")\n",
+    "# source_data = {file: read_single_band_raster(file) for file in p.glob(\"**/*.tif*\")}\n",
+    "\n",
+    "# code to create the directory\n",
+    "# ./result/bin/encoder --bigearthnet-s1-root tiffs/BigEarthNet/S1/ --bigearthnet-s2-root tiffs/BigEarthNet/S2/ artifacts/\n",
+    "env = lmdb.open(\"../artifacts\", readonly=True)\n",
+    "\n",
+    "with env.begin(write=False) as txn:\n",
+    "    cur = txn.cursor()\n",
+    "    decoded_lmdb_data = {k.decode(\"utf-8\"): load(v) for (k, v) in cur}\n",
+    "\n",
+    "# The encoded data is nested inside another safetensors dictionary, where the inner keys are derived from the band suffix\n",
+    "# decoded_values = [v for outer_v in decoded_lmdb_data.values() for v in outer_v.values()]\n",
+    "\n",
+    "# # Simply check if the data remains identical, as this is the only _true_ thing I care about from the Python viewpoint\n",
+    "# # If the keys/order or anything else is wrong, it isn't part of the integration test but should be handled separately as a unit test!\n",
+    "# for (source_key, source_value) in source_data.items():\n",
+    "#     assert any(np.array_equal(source_value, decoded_value) for decoded_value in decoded_values), f\"Couldn't find data in the LMDB database that matches the data from: {source_key}\""
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 30,
+   "id": "3d819f7d-4c0c-41f2-b874-b2a6c8d3d0a5",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "50.2 µs ± 651 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)\n"
+     ]
+    }
+   ],
+   "source": [
+    "%%timeit\n",
+    "for i in range(1, 225):\n",
+    "    decoded_lmdb_data['ENMAP01-____L2A-DT0000004950_20221103T162438Z_001_V010110_20221118T145147Z-Y01460273_X03110438'][str(i)]"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 31,
+   "id": "14bce7df-6aae-4c2d-92ca-f2f921a93268",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "4.74 µs ± 29 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)\n"
+     ]
+    }
+   ],
+   "source": [
+    "%%timeit\n",
+    "for i in range(1, 22):\n",
+    "    decoded_lmdb_data['ENMAP01-____L2A-DT0000004950_20221103T162438Z_001_V010110_20221118T145147Z-Y01460273_X03110438'][str(i)]"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 23,
+   "id": "26d974c0-531e-419e-9f69-656db858292f",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "995 µs ± 9.96 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)\n"
+     ]
+    }
+   ],
+   "source": [
+    "%%timeit\n",
+    "np.stack([decoded_lmdb_data['ENMAP01-____L2A-DT0000004950_20221103T162438Z_001_V010110_20221118T145147Z-Y01460273_X03110438'][str(i)] for i in range(1, 225)], axis=0)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 24,
+   "id": "3383c0a5-2130-4412-9ff7-94666b7a5f72",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "71.4 µs ± 107 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)\n"
+     ]
+    }
+   ],
+   "source": [
+    "%%timeit\n",
+    "np.stack([decoded_lmdb_data['ENMAP01-____L2A-DT0000004950_20221103T162438Z_001_V010110_20221118T145147Z-Y01460273_X03110438'][str(i)] for i in range(1, 22)], axis=0)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "2a427582-9cc8-4493-b909-5791160da3e8",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "%%timeit\n",
+    "a = np.zeros((224, 128, 128))"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 34,
+   "id": "7759f06d-75ef-40ad-a6b9-afed13d5a352",
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "(4, 128, 128)"
+      ]
+     },
+     "execution_count": 34,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "SUPPORTED_BANDS = list(range(5))\n",
+    "np.stack([decoded_lmdb_data['ENMAP01-____L2A-DT0000004950_20221103T162438Z_001_V010110_20221118T145147Z-Y01460273_X03110438'][str(i)] for i in range(1, 22) if i in SUPPORTED_BANDS], axis=0).shape"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 36,
+   "id": "dec80ead-1c04-42d7-99de-92df4c3a7a76",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "5.96 ms ± 185 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n"
+     ]
+    }
+   ],
+   "source": [
+    "%%timeit\n",
+    "a = np.zeros((224, 128, 128))\n",
+    "for i in range(1, 225):\n",
+    "    a[i-1] = decoded_lmdb_data['ENMAP01-____L2A-DT0000004950_20221103T162438Z_001_V010110_20221118T145147Z-Y01460273_X03110438'][str(i)]"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 54,
+   "id": "3341546c-0fe8-4f52-8b66-6190db7002e8",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "2.61 ms ± 76.2 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n"
+     ]
+    }
+   ],
+   "source": [
+    "%%timeit\n",
+    "minimum_value = 0\n",
+    "maximum_value = 10000\n",
+    "\n",
+    "clipped = np.stack([decoded_lmdb_data['ENMAP01-____L2A-DT0000004950_20221103T162438Z_001_V010110_20221118T145147Z-Y01460273_X03110438'][str(i)] for i in range(1, 225)], axis=0).clip(min=minimum_value, max=maximum_value)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 57,
+   "id": "a5a5702d-377b-4218-b40b-99339ecc6946",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "7.15 ms ± 436 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n"
+     ]
+    }
+   ],
+   "source": [
+    "%%timeit\n",
+    "out_data = (clipped - minimum_value) / (maximum_value - minimum_value)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 59,
+   "id": "b256f6b1-10c0-4219-a00e-6a45ec74499b",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "3.21 ms ± 111 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n"
+     ]
+    }
+   ],
+   "source": [
+    "%%timeit\n",
+    "out_dataf = out_data.astype(np.float32)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 60,
+   "id": "ae7be06a-3d12-44af-aa82-c233d65824ec",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "10.1 ms ± 403 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n"
+     ]
+    }
+   ],
+   "source": [
+    "%%timeit\n",
+    "# astype without explicit intermediate value is just as fast as with intermediate value\n",
+    "out_data = ((clipped - minimum_value) / (maximum_value - minimum_value)).astype(np.float32)"
+   ]
+  },
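+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "b7a2c9e1-0a55-4e4e-8a10-000000000001",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Consolidated sketch of the per-patch pipeline benchmarked piecewise above:\n",
+    "# stack all 224 bands, clip to the value range, min-max normalize, cast to float32.\n",
+    "patch_key = 'ENMAP01-____L2A-DT0000004950_20221103T162438Z_001_V010110_20221118T145147Z-Y01460273_X03110438'\n",
+    "minimum_value, maximum_value = 0, 10000\n",
+    "stacked = np.stack([decoded_lmdb_data[patch_key][str(i)] for i in range(1, 225)], axis=0)\n",
+    "clipped = stacked.clip(min=minimum_value, max=maximum_value)\n",
+    "normalized = ((clipped - minimum_value) / (maximum_value - minimum_value)).astype(np.float32)\n",
+    "normalized.shape, normalized.dtype"
+   ]
+  },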
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "37c8e0f7-03d5-4ed4-bd44-99d5415447ad",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# for a single patch the full pipeline takes around 10 ms"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "49f82b44-6cad-44b4-963a-9d39897c2f37",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# 0.72 batches / sec for Martin for the entire training"
+   ]
  }
 ],
 "metadata": {
@@ -109,7 +371,7 @@
    "name": "python",
    "nbconvert_exporter": "python",
    "pygments_lexer": "ipython3",
-   "version": "3.11.8"
+   "version": "3.11.9"
   }
  },
 "nbformat": 4,
diff --git a/integration_tests/tiffs/hyspecnet-11k/ENMAP01-____L2A-DT0000004950_20221103T162438Z_001_V010110_20221118T145147Z-Y01460273_X03110438-SPECTRAL_IMAGE.TIF b/integration_tests/tiffs/hyspecnet-11k/ENMAP01-____L2A-DT0000004950_20221103T162438Z_001_V010110_20221118T145147Z-Y01460273_X03110438-SPECTRAL_IMAGE.TIF
new file mode 100644
index 0000000..4565d9c
Binary files /dev/null and b/integration_tests/tiffs/hyspecnet-11k/ENMAP01-____L2A-DT0000004950_20221103T162438Z_001_V010110_20221118T145147Z-Y01460273_X03110438-SPECTRAL_IMAGE.TIF differ
diff --git a/src/main.rs b/src/main.rs
index 3893e3c..e8385bb 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -60,14 +60,17 @@ struct DataKeyPair {
 enum TypedDataKeyPair {
     BigEarthNetS1(DataKeyPair),
     BigEarthNetS2(DataKeyPair),
-    // ENMAP
+    HyspecNet(DataKeyPair),
 }
 
+const N_HYSPECNET_BANDS: isize = 224;
+
 impl TypedDataKeyPair {
     fn get_safetensors_key(&self) -> &String {
         match self {
             Self::BigEarthNetS1(d) => &d.safetensors_key,
             Self::BigEarthNetS2(d) => &d.safetensors_key,
+            Self::HyspecNet(d) => &d.safetensors_key,
         }
     }
 }
 
@@ -76,6 +79,7 @@
 enum Satellite {
     Sentinel1,
     Sentinel2,
+    Enmap,
 }
 
 /// Encoder that converts TIFF files into `safetensors` values and embeds them in an LMDB database.
@@ -123,6 +127,10 @@ struct Cli {
     /// Path to the BigEarthNet-S2 root directory.
     #[arg(long, value_name = "ROOT_DIR")]
     bigearthnet_s2_root: Option<PathBuf>,
+
+    /// Path to the HyspecNet root directory.
+    #[arg(long, value_name = "ROOT_DIR")]
+    hyspecnet_root: Option<PathBuf>,
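+    // Hypothetical example invocation (paths assumed, mirroring the
+    // BigEarthNet call in the integration notebook):
+    //   ./result/bin/encoder --hyspecnet-root tiffs/hyspecnet-11k/ artifacts/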
 }
 
 fn main() -> anyhow::Result<()> {
@@ -130,19 +138,27 @@ fn main() -> anyhow::Result<()> {
     let mut v = Vec::new();
     if let Some(bigearthnet_s1_root) = cli.bigearthnet_s1_root {
         println!("Starting to process BigEarthNet-S1");
-        v.push(generate_grouped_files_from_bigearthnet(
+        v.push(generate_grouped_files(
            bigearthnet_s1_root.to_str().unwrap(),
             Satellite::Sentinel1,
         ));
     }
     if let Some(bigearthnet_s2_root) = cli.bigearthnet_s2_root {
         println!("Starting to process BigEarthNet-S2");
-        v.push(generate_grouped_files_from_bigearthnet(
+        v.push(generate_grouped_files(
             bigearthnet_s2_root.to_str().unwrap(),
             Satellite::Sentinel2,
         ));
     }
+    if let Some(hyspecnet_root) = cli.hyspecnet_root {
+        println!("Starting to process HyspecNet");
+        v.push(generate_grouped_files(
+            hyspecnet_root.to_str().unwrap(),
+            Satellite::Enmap,
+        ));
+    }
+
     if v.len() == 0 {
         println!("No dataset selected! Nothing will be generated!");
         return Ok(());
@@ -264,14 +280,14 @@ fn bigearthnet_s2_ordering(a: &str, b: &str) -> Ordering {
     }
 }
 
-/// TODO: Update the outdated description!
-/// Given a regular expression `pattern` with the keys `group` and `safetensorsKey`
-/// loop over all `paths` (files), extract the file stems and apply the regular expression.
-/// The extracted `group` will become the key of the returning HashMap
+/// Given a list of file `paths`, loop over each file, extract the file stem,
+/// and build the grouped-files map.
+/// The function applies a pre-defined regular expression (selected via the `satellite` enum)
+/// whose `group` capture will become the key of the returned hashmap,
 /// and the associated value will be pushed to a vector, where each value is a `DataKeyPair`
 /// with the `path` set to the considered `path` and the `safetensors_key` to the matched
 /// `safetensorsKey` from the regular expression.
-fn generate_grouped_files_from_bigearthnet_paths(
+fn generate_grouped_files_from_paths(
     paths: Vec<PathBuf>,
     satellite: Satellite,
 ) -> HashMap<String, Vec<TypedDataKeyPair>> {
     let pattern_str = match satellite {
         Satellite::Sentinel2 => r"(?<group>.*)_(?<safetensorsKey>B[0-9A]+)$",
         Satellite::Sentinel1 => r"(?<group>.*)_(?<safetensorsKey>V[VH])$",
+        // ENMAP01-____L2A-DT0000004950_20221103T162438Z_001_V010110_20221118T145147Z-Y01460273_X03110438-SPECTRAL_IMAGE.TIF
+        Satellite::Enmap => r"(?<group>.*)-(?<safetensorsKey>SPECTRAL_IMAGE)$",
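+        // For the example filename above this pattern yields (sketch):
+        //   group          = "ENMAP01-____L2A-DT0000004950_20221103T162438Z_001_V010110_20221118T145147Z-Y01460273_X03110438"
+        //   safetensorsKey = "SPECTRAL_IMAGE"
+        // i.e. all bands of one HyspecNet patch end up under a single LMDB key.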
Report bug!"), ), _ => panic!("Unsupported data type detected!"), - } + }; + // TODO: Understand if `arr` is copied or not! + (datakeypair.safetensors_key.clone(), arr) +} + +// Function cannot consume the datakeypair +// name mapping could also be implemented by caller +fn mk_hyspecnet_safetensor( + datakeypair: &DataKeyPair, + index: isize, +) -> (String, SupportedWrapper>) { + let dataset = + Dataset::open(datakeypair.path.clone()).expect("Current file should have read access!"); + let band1 = dataset + .rasterband(index) + .expect("Tiff files should contain at least one band!"); + // tuples (x, y) are in (cols, rows) order + // `window` is the (x, y) coordinate of the upper left corner of the region to read + let window = (0, 0); + // `window_size` is the amount to read -> We will always read everything! + let window_size = band1.size(); + // assert_eq!(band1.band_type(), GdalDataType::UInt16); + let arr = match band1.band_type() { + GdalDataType::Int16 => SupportedWrapper::I16( + band1 + .read_as_array::(window, window_size, window_size, None) + .expect("File should open correctly. Report bug!"), + ), + _ => panic!("Unsupported data type detected!"), + }; + // (datakeypair.safetensors_key.clone(), arr) + (index.to_string(), arr) } /// Given a `DataKeyPair` `d` vector, iterate through all elements @@ -421,32 +477,15 @@ fn mk_bigearthnet_safetensor( /// Then construct the given safetensor data and return the resulting /// data vector. fn mk_safetensors(pairs: &Vec) -> anyhow::Result> { - let it = pairs.into_iter().map(|e| match e { - TypedDataKeyPair::BigEarthNetS1(e) => mk_bigearthnet_safetensor(&e), - TypedDataKeyPair::BigEarthNetS2(e) => mk_bigearthnet_safetensor(&e), - // _ => panic!("Not implemented yet!"), - // For the keys, I need to apply a slightly different logic for enmap - // So depending on the value of `e` below, I might want to... - // fuck... - // No, the mk_bigearthnet_safetensor code here also needs to directly return - // the key, as I need to repeat the execution for every possible band - // and merge the current e.safetensor_key with the band index - // I mainly need to ensure that the bands/arrays are read iteratively and not - // materialized in the beginning - // The solution is probably something like flat_map, where every `e` value - // returns an iterator of the desired type - // In the future, I should also avoid this 'value' based matching - // I should keep the vectors distinct and only merge the keys once for checking - // Then there is no need to continuously jump around on every element + let it = pairs.into_iter().flat_map(|e| match e { + TypedDataKeyPair::BigEarthNetS1(e) => vec![mk_bigearthnet_safetensor(&e)], + TypedDataKeyPair::BigEarthNetS2(e) => vec![mk_bigearthnet_safetensor(&e)], + // GDAL is 1-indexed! + TypedDataKeyPair::HyspecNet(e) => (1..(N_HYSPECNET_BANDS + 1)) + .map(|i| mk_hyspecnet_safetensor(&e, i)) + .collect::>)>>(), }); - - Ok(serialize( - pairs - .iter() - .map(|e| e.get_safetensors_key().clone()) - .zip(it), - &None, - )?) + Ok(serialize(it, &None)?) } // Code from GitHub issue: @@ -468,6 +507,16 @@ impl SupportedWrapper { }; new_slice } + SupportedWrapper::I16(arr) => { + let slice = arr.as_slice().expect("Non-contiguous memory for tensor!"); + let num_bytes = std::mem::size_of::(); + let new_slice: &[u8] = unsafe { + // len is the number of elements not the number of bytes! 
+    Ok(serialize(it, &None)?)
 }
 
 // Code from GitHub issue:
@@ -468,6 +507,16 @@ impl<D> SupportedWrapper<D> {
             };
             new_slice
         }
+        SupportedWrapper::I16(arr) => {
+            let slice = arr.as_slice().expect("Non-contiguous memory for tensor!");
+            let num_bytes = std::mem::size_of::<i16>();
+            let new_slice: &[u8] = unsafe {
+                // `from_raw_parts` takes the length in elements, not bytes;
+                // as the target element type is u8, the element count equals the byte count.
+                std::slice::from_raw_parts(slice.as_ptr() as *const u8, slice.len() * num_bytes)
+            };
+            new_slice
+        }
         SupportedWrapper::F32(arr) => {
             let slice = arr.as_slice().expect("Non-contiguous memory for tensor!");
             let num_bytes = std::mem::size_of::<f32>();
@@ -487,6 +536,7 @@ impl<D> View for SupportedWrapper<D> {
         match self {
             SupportedWrapper::U16(_) => Dtype::U16,
             SupportedWrapper::F32(_) => Dtype::F32,
+            SupportedWrapper::I16(_) => Dtype::I16,
         }
     }
 
@@ -494,6 +544,7 @@
         match self {
             SupportedWrapper::U16(arr) => arr.shape(),
             SupportedWrapper::F32(arr) => arr.shape(),
+            SupportedWrapper::I16(arr) => arr.shape(),
         }
     }