Skip to content

Commit

Permalink
Merge pull request #20 from dlr-eoc/onnx_session_options
Browse files Browse the repository at this point in the history
Onnx session options
  • Loading branch information
MWieland authored Dec 5, 2023
2 parents 17070b7 + 97216f4 commit 42b294c
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 18 deletions.
10 changes: 7 additions & 3 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
@@ -1,11 +1,15 @@
Changelog
=========

[0.2.0] (2023-05-24)
[0.2.0] (2023-12-04)
--------------------
Fixed
Added
*******
- support onnxruntime session options

Changed
*******
- new models in manifest
- band naming scheme to match STAC standard

[0.1.9] (2023-05-12)
--------------------
Expand Down
7 changes: 4 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
[![Code Style](https://img.shields.io/badge/code%20style-black-000000.svg)](https://black.readthedocs.io/en/stable/)
[![DOI](https://zenodo.org/badge/328616234.svg)](https://zenodo.org/badge/latestdoi/328616234)

UKIS Cloud Shadow MASK (ukis-csmask) package masks clouds and cloud shadows in Sentinel-2, Landsat-9, Landsat-8, Landsat-7 and Landsat-5 images. Masking is performed with a pre-trained convolution neural network. It is fast and works directly on Level-1C data (no atmospheric correction required). Images just need to be in Top Of Atmosphere (TOA) reflectance and include at least the "Blue", "Green", "Red" and "NIR" spectral bands. Best performance (in terms of accuracy and speed) is achieved when images also include "SWIR1" and "SWIR2" spectral bands and are resampled to approximately 30 m spatial resolution.
UKIS Cloud Shadow MASK (ukis-csmask) package masks clouds and cloud shadows in Sentinel-2, Landsat-9, Landsat-8, Landsat-7 and Landsat-5 images. Masking is performed with a pre-trained convolution neural network. It is fast and works directly on Level-1C data (no atmospheric correction required). Images just need to be in Top Of Atmosphere (TOA) reflectance and include at least the "blue", "green", "red" and "nir" spectral bands. Best performance (in terms of accuracy and speed) is achieved when images also include "swir16" and "swir22" spectral bands and are resampled to approximately 30 m spatial resolution.

This [publication](https://doi.org/10.1016/j.rse.2019.05.022) provides further insight into the underlying algorithm and compares it to the widely used [Fmask](http://www.pythonfmask.org/en/latest/) algorithm across a heterogeneous test dataset.

Expand Down Expand Up @@ -45,7 +45,7 @@ img.warp(
# make sure to use these six spectral bands to get best performance
csmask = CSmask(
img=img.arr,
band_order=["Blue", "Green", "Red", "NIR", "SWIR1", "SWIR2"],
band_order=["blue", "green", "red", "nir", "swir16", "swir22"],
nodata_value=0,
)

Expand All @@ -63,6 +63,7 @@ csmask_valid = Image(csmask.valid, transform=img.dataset.transform, crs=img.data
csmask_csm.write_to_file("sentinel2_csm.tif", dtype="uint8", compress="PACKBITS")
csmask_valid.write_to_file("sentinel2_valid.tif", dtype="uint8", compress="PACKBITS", kwargs={"nbits":2})
````

## Example (Landsat 8)
Here's a similar example based on Landsat 8.

Expand Down Expand Up @@ -101,7 +102,7 @@ img = Image(data=L8_bands, crs = meta['crs'], transform = meta['transform'], dim
img.dn2toa(
platform=Platform.Landsat8,
mtl_file=mtl_file,
wavelengths = ["Blue", "Green", "Red", "NIR", "SWIR1", "SWIR2"]
wavelengths = ["blue", "green", "red", "nir", "swir16", "swir22"]
)
# >> proceed by analogy with Sentinel 2 example
````
Expand Down
38 changes: 26 additions & 12 deletions ukis_csmask/mask.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ def __init__(
band_order,
nodata_value=None,
invalid_buffer=4,
intra_op_num_threads=0,
inter_op_num_threads=0,
):
"""
:param img: Input satellite image of shape (rows, cols, bands). (ndarray).
Expand All @@ -35,9 +37,11 @@ def __init__(
For better performance requires image bands to include "Blue", "Green", "Red", "NIR", "SWIR1", "SWIR2" (runs 6 band model).
For better performance requires image bands to be in approximately 30 m resolution.
:param band_order: Image band order. (list of string).
>>> band_order = ["Blue", "Green", "Red", "NIR", "SWIR1", "SWIR2"]
>>> band_order = ["blue", "green", "red", "nir", "swir16", "swir22"]
:param nodata_value: Additional nodata value that will be added to valid mask. (num).
:param invalid_buffer: Number of pixels that should be buffered around invalid areas.
:param invalid_buffer: Number of pixels that should be buffered around invalid areas. (int).
:param intra_op_num_threads: Sets the number of threads used to parallelize the execution within nodes. Default is 0 to let onnxruntime choose. (int).
:param inter_op_num_threads: Sets the number of threads used to parallelize the execution of the graph (across nodes). Default is 0 to let onnxruntime choose. (int).
"""
# consistency checks on input image
if isinstance(img, np.ndarray) is False:
Expand All @@ -63,20 +67,25 @@ def __init__(
if band_order is None:
raise TypeError("band_order cannot be None")

if all(elem in band_order for elem in ["Blue", "Green", "Red", "NIR", "SWIR1", "SWIR2"]):
target_band_order = ["Blue", "Green", "Red", "NIR", "SWIR1", "SWIR2"]
# ensure backwards compatibility with old band naming scheme
band_order = [b.lower() for b in band_order]
band_order = [b.replace("swir1", "swir16") for b in band_order]
band_order = [b.replace("swir2", "swir22") for b in band_order]

if all(elem in band_order for elem in ["blue", "green", "red", "nir", "swir16", "swir22"]):
target_band_order = ["blue", "green", "red", "nir", "swir16", "swir22"]
band_mean = [0.19312, 0.18659, 0.18899, 0.30362, 0.23085, 0.16216]
band_std = [0.16431, 0.16762, 0.18230, 0.17409, 0.16020, 0.14164]
model_file = str(Path(__file__).parent) + "/model_6b.onnx"
elif all(elem in band_order for elem in ["Blue", "Green", "Red", "NIR"]):
target_band_order = ["Blue", "Green", "Red", "NIR"]
elif all(elem in band_order for elem in ["blue", "green", "red", "nir"]):
target_band_order = ["blue", "green", "red", "nir"]
band_mean = [0.19312, 0.18659, 0.18899, 0.30362]
band_std = [0.16431, 0.16762, 0.18230, 0.17409]
model_file = str(Path(__file__).parent) + "/model_4b.onnx"
else:
raise TypeError(
f"band_order must contain at least 'Blue', 'Green', 'Red', 'NIR' "
f"and for better performance also 'SWIR1' and 'SWIR2'"
f"band_order must contain at least 'blue', 'green', 'red', 'nir' "
f"and for better performance also 'swir16' and 'swir22'"
)

# rearrange image bands to match target_band_order
Expand All @@ -85,6 +94,14 @@ def __init__(
)
img = img[:, :, idx]

# start onnx inference session and load model
so = onnxruntime.SessionOptions()
so.intra_op_num_threads = intra_op_num_threads
so.inter_op_num_threads = inter_op_num_threads
self.sess = onnxruntime.InferenceSession(
model_file, sess_options=so, providers=onnxruntime.get_available_providers()
)

self.img = img
self.band_order = band_order
self.band_mean = band_mean
Expand All @@ -106,11 +123,8 @@ def _csm(self):
x -= self.band_mean
x /= self.band_std

# start onnx inference session and load model
sess = onnxruntime.InferenceSession(self.model_file, providers=onnxruntime.get_available_providers())

# predict on array tiles
y_prob = [sess.run(None, {"input_1": tile[np.newaxis, :]}) for n, tile in enumerate(list(x))]
y_prob = [self.sess.run(None, {"input_1": tile[np.newaxis, :]}) for n, tile in enumerate(list(x))]
y_prob = np.concatenate(y_prob)[:, 0, :, :, :]

# untile probabilities with smooth blending
Expand Down

0 comments on commit 42b294c

Please sign in to comment.