From 50668897523b3b90b7a3da588c3dd635c7ca5916 Mon Sep 17 00:00:00 2001 From: Michal Charemza Date: Fri, 8 Nov 2024 07:58:33 +0000 Subject: [PATCH] feat: allow public buckets, i.e. without auth It turns out that public buckets, i.e. ones that don't require any auth, should not have any authentication headers at all. This change therefore allows the `get_credentials` parameter to be `None`, in which case request signing is skipped and no authentication headers are sent. As discussed at https://github.com/michalc/sqlite-s3-query/discussions/94 --- README.md | 20 ++++++++++++++++++ sqlite_s3_query.py | 26 +++++++++++++++-------- test.py | 51 ++++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 89 insertions(+), 8 deletions(-) diff --git a/README.md b/README.md index 8da7133..32ad7dc 100644 --- a/README.md +++ b/README.md @@ -175,6 +175,26 @@ with \ print(row) ``` + +### Public Buckets + +For public buckets where credentials should not be passed, pass `None` as the `get_credentials` parameter. + +```python +query_my_db = partial(sqlite_s3_query, + url='https://my-public-bucket.s3.eu-west-2.amazonaws.com/my-db.sqlite', + get_credentials=None, +) + +with \ + query_my_db() as query, \ + query('SELECT * FROM my_table_2 WHERE my_col = ?', params=('my-value',)) as (columns, rows): + + for row in rows: + print(row) +``` + + ### HTTP Client The HTTP client can be changed by overriding the the default `get_http_client` parameter, which is shown below. 
diff --git a/sqlite_s3_query.py b/sqlite_s3_query.py index 4da69ef..f227580 100644 --- a/sqlite_s3_query.py +++ b/sqlite_s3_query.py @@ -74,6 +74,23 @@ def sqlite_s3_query_multi(url, get_credentials=lambda now: ( local = threading.local() local.pending_exception = None + def get_request_headers_for_private_buckets(method, params, headers, now): + region, access_key_id, secret_access_key, session_token = get_credentials(now) + to_auth_headers = headers + ( + (('x-amz-security-token', session_token),) if session_token is not None else \ + () + ) + return aws_sigv4_headers( + now, access_key_id, secret_access_key, region, method, to_auth_headers, params, + ) + + def get_request_headers_for_public_buckets(_, __, headers, ___): + return headers + + get_request_headers = \ + get_request_headers_for_private_buckets if get_credentials is not None else \ + get_request_headers_for_public_buckets + def set_pending_exception(exception): local.pending_exception = exception @@ -98,14 +115,7 @@ def run_with_db(db, func, *args): @contextmanager def make_auth_request(http_client, method, params, headers): now = datetime.utcnow() - region, access_key_id, secret_access_key, session_token = get_credentials(now) - to_auth_headers = headers + ( - (('x-amz-security-token', session_token),) if session_token is not None else \ - () - ) - request_headers = aws_sigv4_headers( - now, access_key_id, secret_access_key, region, method, to_auth_headers, params, - ) + request_headers = get_request_headers(method, params, headers, now) url = f'{scheme}://{netloc}{path}' with http_client.stream(method, url, params=params, headers=request_headers) as response: response.raise_for_status() diff --git a/test.py b/test.py index 17e605a..37210e0 100644 --- a/test.py +++ b/test.py @@ -148,6 +148,27 @@ def test_select_with_named_params(self): self.assertEqual(rows, [(500,)]) + def test_select_with_named_params_public_bucket(self): + create_bucket('my-public-bucket') + disable_auth('my-public-bucket') + with 
get_db([ + ("CREATE TABLE my_table (my_col_a text, my_col_b text);", ()) + ] + [ + ("INSERT INTO my_table VALUES " + ','.join(["('some-text-a', 'some-text-b')"] * 500), ()), + ("INSERT INTO my_table VALUES " + ','.join(["('some-text-c', 'some-text-d')"] * 100), ()), + ]) as db: + put_object_with_versioning('my-public-bucket', 'my.db', db) + + with sqlite_s3_query( + 'http://localhost:9000/my-public-bucket/my.db', + get_credentials=None, + get_libsqlite3=get_libsqlite3 + ) as query: + with query('SELECT COUNT(*) FROM my_table WHERE my_col_a = :first', named_params=((':first', 'some-text-a'),)) as (columns, rows): + rows = list(rows) + + self.assertEqual(rows, [(500,)]) + def test_select_large(self): empty = (bytes(4050),) @@ -840,6 +861,36 @@ def enable_versioning(bucket): response = httpx.put(url, content=content, headers=headers) response.raise_for_status() +def disable_auth(bucket): + content = f''' + {{ + "Version": "2012-10-17", + "Statement": [ + {{ + "Sid": "Stmt1405592139000", + "Effect": "Allow", + "Principal": "*", + "Action": [ + "s3:GetObject", + "s3:GetObjectVersion" + ], + "Resource": [ + "arn:aws:s3:::{bucket}/*" + ] + }} + ] + }} + '''.encode() + url = f'http://127.0.0.1:9000/{bucket}/?policy' + body_hash = hashlib.sha256(content).hexdigest() + parsed_url = urllib.parse.urlsplit(url) + + headers = aws_sigv4_headers( + 'AKIAIOSFODNN7EXAMPLE', 'wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY', + (), 's3', 'us-east-1', parsed_url.netloc, 'PUT', parsed_url.path, (('policy', ''),), body_hash, + ) + response = httpx.put(url, content=content, headers=headers) + response.raise_for_status() def aws_sigv4_headers(access_key_id, secret_access_key, pre_auth_headers, service, region, host, method, path, params, body_hash):