Skip to content

Commit

Permalink
finish chromadb_adapter.py venomx update + tests
Browse files Browse the repository at this point in the history
  • Loading branch information
iQuxLE committed Oct 21, 2024
1 parent cbfb0d7 commit 1676367
Show file tree
Hide file tree
Showing 4 changed files with 29 additions and 48 deletions.
2 changes: 1 addition & 1 deletion src/curategpt/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@
"main",
]

from venomx.model.venomx import Dataset, Index, Model, ModelInputMethod
from venomx.model.venomx import Dataset, Model

from curategpt.store.metadata import Metadata

Expand Down
29 changes: 19 additions & 10 deletions src/curategpt/store/chromadb_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -325,6 +325,7 @@ def set_collection_metadata(
:param metadata:
:return:
"""
# TODO: (for carlo) just call update
chromadb_metadata = metadata.serialize_venomx_metadata_for_adapter(self.name)
self.client.get_or_create_collection(
name=collection_name,
Expand All @@ -345,24 +346,27 @@ def update_collection_metadata(self, collection_name: str, **kwargs) -> Metadata
metadata = self.collection_metadata(collection_name=collection_name)

if metadata is not None:
scalar_updates = {k: v for k, v in kwargs.items() if k != "venomx"}
scalar_updates = {k: v for k, v in kwargs.items() if k != "venomx"} # any additional model param or object type
metadata = metadata.model_copy(update=scalar_updates)
prev_model = metadata.venomx.embedding_model.name
if prev_model and metadata.model != prev_model:
if self.client.get_or_create_collection(name=collection_name).count() > 0:
raise ValueError(f"Cannot change model from {prev_model} to {metadata.model}")

# assign venomx to metadata object
if "venomx" in kwargs and kwargs.get("venomx") is not None:
# assign venomx to metadata object
metadata.venomx = kwargs.get("venomx")
metadata = Metadata(venomx=kwargs.get("venomx"))
else:
metadata = Metadata(
venomx=kwargs.get("venomx"),
# hnsw_space=kwargs.get("hnsw_space", "cosine"),
# object_type=kwargs.get("object_type"),
hnsw_space=kwargs.get("hnsw_space", "cosine"),
object_type=kwargs.get("object_type"),
)

# Ensure 'venomx.id' matches 'collection_name' if venomx is provided
if metadata.venomx:
if metadata:
if metadata.venomx.id != collection_name:
print(f"venomx.id: {metadata.venomx.id} must match collection_name {collection_name}")
metadata.venomx.id = collection_name
raise ValueError(f"venomx.id: {metadata.venomx.id} must match collection_name {collection_name}")

# metadata.hnsw_space = "cosine"
chromadb_metadata = metadata.serialize_venomx_metadata_for_adapter(self.name)
Expand Down Expand Up @@ -426,8 +430,11 @@ def _search(
# want to accidentally set it
collection = client.get_collection(name=self._get_collection(collection))
metadata = collection.metadata
# deserialize _venomx str to venomx dict and put in Metadata model
metadata = json.loads(metadata["_venomx"])
metadata = Metadata(venomx=Index(**metadata))
collection = client.get_collection(
name=collection.name, embedding_function=self._embedding_function(metadata["model"])
name=collection.name, embedding_function=self._embedding_function(metadata.venomx.embedding_model.name)
)
logger.debug(f"Collection metadata: {metadata}")
if text:
Expand Down Expand Up @@ -522,7 +529,9 @@ def diversified_search(
)
collection_obj = self._get_collection_object(collection)
metadata = collection_obj.metadata
ef = self._embedding_function(metadata["model"])
metadata = json.loads(metadata["_venomx"])
metadata = Metadata(venomx=Index(**metadata))
ef = self._embedding_function(metadata.venomx.embedding_model.name)
if len(text) > self.default_max_document_length:
logger.warning(
f"Text too long ({len(text)}), truncating to {self.default_max_document_length}"
Expand Down
40 changes: 7 additions & 33 deletions tests/store/test_chromadb_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,25 +44,6 @@ def simple_schema_manager() -> SchemaProxy:
)
return SchemaProxy(sb.schema)

# not finisished
@pytest.mark.skip
def test_setting_collection_metadata(example_texts):
db = ChromaDBAdapter(str(OUTPUT_CHROMA_DB_PATH))
db.client.reset()
assert db.list_collection_names() == []
collection = "test"
objs = terms_to_objects(example_texts)
db.insert(objs, collection=collection)
md = db.collection_metadata(collection)
md.venomx.id = "test collection"
md.venomx.embedding_model.name = "openai:"
db.set_collection_metadata(collection, md)
assert md.venomx.id == "test collection"
assert db.collection_metadata(collection).venomx.id == "test collection"
assert db.collection_metadata(collection).venomx.embedding_model.name == "openai:"

# not finisished
@pytest.mark.skip
def test_store(simple_schema_manager, example_texts):
db = ChromaDBAdapter(str(OUTPUT_CHROMA_DB_PATH))
db.schema_proxy = simple_schema_manager
Expand All @@ -73,24 +54,19 @@ def test_store(simple_schema_manager, example_texts):
db.insert(objs, collection=collection)
md = db.collection_metadata(collection)
md.venomx.id = "test collection"
md.venomx.embedding_model.name = "openai:"
db.set_collection_metadata(collection, md)
assert md.venomx.id == "test collection"
assert db.collection_metadata(collection).venomx.id == "test collection"
assert db.collection_metadata(collection).venomx.embedding_model.name == "openai:"


db2 = ChromaDBAdapter(str(OUTPUT_CHROMA_DB_PATH))
assert db2.collection_metadata(collection).description == "test collection"
assert db2.collection_metadata(collection).venomx.id == "test collection"
assert db.list_collection_names() == ["test"]
results = list(db.search("fox", collection=collection))
print(results)
# print(results)
for obj in objs:
print(f"QUERYING: {obj}")
for match in db.matches(obj, collection=collection):
print(f" - MATCH: {match}")
db.update(objs, collection=collection)
assert db.collection_metadata(collection).description == "test collection"
canines = list(db.find(where={"text": {"$eq": "canine"}}, collection=collection))
print(f"CANINES: {canines}")
long_words = list(db.find(where={"wordlen": {"$gt": 12}}, collection=collection))
Expand Down Expand Up @@ -148,16 +124,14 @@ def test_embedding_function(simple_schema_manager, example_texts):
db.insert(objs[1:])
db.insert(objs[1:], collection="default_ef", model=None)
db.insert(objs[1:], collection="openai", model="openai:")
assert db.collection_metadata("default_ef").name == "default_ef"
assert db.collection_metadata("openai").name == "openai"
assert db.collection_metadata(None).model == "all-MiniLM-L6-v2"
assert db.collection_metadata("default_ef").model == "all-MiniLM-L6-v2"
assert db.collection_metadata("openai").model == "openai:"
assert db.collection_metadata("default_ef").venomx.id == "default_ef"
assert db.collection_metadata("openai").venomx.id == "openai"
assert db.collection_metadata(None).venomx.embedding_model.name == "all-MiniLM-L6-v2"
assert db.collection_metadata("default_ef").venomx.embedding_model.name == "all-MiniLM-L6-v2"
assert db.collection_metadata("openai").venomx.embedding_model.name == "openai:"
db.insert([objs[0]])
db.insert([objs[0]], collection="default_ef")
db.insert([objs[0]], collection="openai")
assert db.collection_metadata("default_ef").model == "all-MiniLM-L6-v2"
assert db.collection_metadata("openai").model == "openai:"
results_ef = list(db.search("fox", collection="default_ef"))
results_oai = list(db.search("fox", collection="openai"))
assert len(results_ef) > 0
Expand Down
6 changes: 2 additions & 4 deletions tests/wrappers/test_ontology.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
logger.setLevel(logging.DEBUG)

# for debugging meanwhile implementing
@pytest.mark.skip
def test_insert_without_venomx():
db = setup_db(Path("../db"))
collection_name = "test_collection_without_venomx_set_upfront"
Expand Down Expand Up @@ -62,7 +61,6 @@ def test_insert_without_venomx():
print(metadata)

# for debugging meanwhile implementing
@pytest.mark.skip
def test_insert_with_venomx():
db = setup_db(Path("../db"))
collection_name = "test_collection_with_venomx_set_upfront"
Expand All @@ -84,8 +82,8 @@ def test_insert_with_venomx():
)
)

print(venomx)

# print(venomx)
# print("\n\n", venomx.id ,"\n\n")

db.insert(
wrapper.objects(),
Expand Down

0 comments on commit 1676367

Please sign in to comment.