From 0a82876cd9682762be158254404f10b8c3e53906 Mon Sep 17 00:00:00 2001 From: Carles Onielfa Date: Thu, 9 Jan 2025 08:28:30 +0100 Subject: [PATCH 01/17] Fix NIDX relation metadata indexing --- nidx/nidx_relation/src/resource_indexer.rs | 6 ++--- nidx/tests/test_search_relations.rs | 27 ++++++++++++++++++++-- 2 files changed, 28 insertions(+), 5 deletions(-) diff --git a/nidx/nidx_relation/src/resource_indexer.rs b/nidx/nidx_relation/src/resource_indexer.rs index fda6200d12..f2f9558f43 100644 --- a/nidx/nidx_relation/src/resource_indexer.rs +++ b/nidx/nidx_relation/src/resource_indexer.rs @@ -37,7 +37,7 @@ pub fn index_relations( let source = relation.source.as_ref().expect("Missing source"); let source_value = source.value.as_str(); let source_type = io_maps::node_type_to_u64(source.ntype()); - let soruce_subtype = source.subtype.as_str(); + let source_subtype = source.subtype.as_str(); let target = relation.to.as_ref().expect("Missing target"); let target_value = target.value.as_str(); @@ -55,7 +55,7 @@ pub fn index_relations( schema.resource_id => resource_id, schema.source_value => source_value, schema.source_type => source_type, - schema.source_subtype => soruce_subtype, + schema.source_subtype => source_subtype, schema.target_value => target_value, schema.target_type => target_type, schema.target_subtype => target_subtype, @@ -65,7 +65,7 @@ pub fn index_relations( if let Some(metadata) = relation.metadata.as_ref() { let encoded_metadata = metadata.encode_to_vec(); - new_doc.add_bytes(schema.label, encoded_metadata); + new_doc.add_bytes(schema.metadata, encoded_metadata); } writer.add_document(new_doc)?; diff --git a/nidx/tests/test_search_relations.rs b/nidx/tests/test_search_relations.rs index 8844a8fcc7..63dca4a008 100644 --- a/nidx/tests/test_search_relations.rs +++ b/nidx/tests/test_search_relations.rs @@ -29,8 +29,9 @@ use nidx_protos::relation::RelationType; use nidx_protos::relation_node::NodeType; use nidx_protos::resource::ResourceStatus; use nidx_protos::{ - EntitiesSubgraphRequest, IndexMetadata, NewShardRequest, Relation, RelationNode, RelationNodeFilter, - RelationPrefixSearchRequest, RelationSearchRequest, RelationSearchResponse, Resource, ResourceId, + EntitiesSubgraphRequest, IndexMetadata, NewShardRequest, Relation, RelationMetadata, RelationNode, + RelationNodeFilter, RelationPrefixSearchRequest, RelationSearchRequest, RelationSearchResponse, Resource, + ResourceId, }; use nidx_protos::{SearchRequest, VectorIndexConfig}; use sqlx::PgPool; @@ -259,6 +260,15 @@ async fn create_knowledge_graph(fixture: &mut NidxFixture, shard_id: String) -> source: Some(relation_nodes.get("Poetry").unwrap().clone()), to: Some(relation_nodes.get("Swallow").unwrap().clone()), relation_label: "about".to_string(), + metadata: Some(RelationMetadata { + paragraph_id: Some("myresource/0/myresource/100-200".to_string()), + source_start: Some(0), + source_end: Some(10), + to_start: Some(11), + to_end: Some(20), + data_augmentation_task_id: Some("mytask".to_string()), + ..Default::default() + }), ..Default::default() }, Relation { @@ -642,6 +652,19 @@ async fn test_search_relations_neighbours(pool: PgPool) -> Result<(), Box Date: Thu, 9 Jan 2025 08:42:15 +0100 Subject: [PATCH 02/17] Fix also old index --- nucliadb_node/tests/test_search_relations.rs | 27 ++++++++++++++++++-- nucliadb_relations2/src/writer.rs | 6 ++--- 2 files changed, 28 insertions(+), 5 deletions(-) diff --git a/nucliadb_node/tests/test_search_relations.rs b/nucliadb_node/tests/test_search_relations.rs index 
b90a964876..d4606c3a04 100644 --- a/nucliadb_node/tests/test_search_relations.rs +++ b/nucliadb_node/tests/test_search_relations.rs @@ -30,8 +30,9 @@ use nucliadb_core::protos::relation::RelationType; use nucliadb_core::protos::relation_node::NodeType; use nucliadb_core::protos::resource::ResourceStatus; use nucliadb_core::protos::{ - EntitiesSubgraphRequest, IndexMetadata, NewShardRequest, Relation, RelationNode, RelationNodeFilter, - RelationPrefixSearchRequest, RelationSearchRequest, RelationSearchResponse, Resource, ResourceId, + EntitiesSubgraphRequest, IndexMetadata, NewShardRequest, Relation, RelationMetadata, RelationNode, + RelationNodeFilter, RelationPrefixSearchRequest, RelationSearchRequest, RelationSearchResponse, Resource, + ResourceId, }; use nucliadb_protos::nodereader::SearchRequest; use rstest::*; @@ -260,6 +261,15 @@ async fn create_knowledge_graph(writer: &mut TestNodeWriter, shard_id: String) - source: Some(relation_nodes.get("Poetry").unwrap().clone()), to: Some(relation_nodes.get("Swallow").unwrap().clone()), relation_label: "about".to_string(), + metadata: Some(RelationMetadata { + paragraph_id: Some("myresource/0/myresource/100-200".to_string()), + source_start: Some(0), + source_end: Some(10), + to_start: Some(11), + to_end: Some(20), + data_augmentation_task_id: Some("mytask".to_string()), + ..Default::default() + }), ..Default::default() }, Relation { @@ -609,6 +619,19 @@ async fn test_search_relations_neighbours() -> Result<(), Box resource_id, self.schema.source_value => source_value, self.schema.source_type => source_type, - self.schema.source_subtype => soruce_subtype, + self.schema.source_subtype => source_subtype, self.schema.target_value => target_value, self.schema.target_type => target_type, self.schema.target_subtype => target_subtype, @@ -199,7 +199,7 @@ impl RelationsWriterService { if let Some(metadata) = relation.metadata.as_ref() { let encoded_metadata = metadata.encode_to_vec(); - new_doc.add_bytes(self.schema.label, encoded_metadata); + new_doc.add_bytes(self.schema.metadata, encoded_metadata); } self.writer.add_document(new_doc)?; From f08312e3f67de8d4235c0c726983551e993a8246 Mon Sep 17 00:00:00 2001 From: Carles Onielfa Date: Thu, 9 Jan 2025 09:07:22 +0100 Subject: [PATCH 03/17] Add unit tests --- nidx/nidx_relation/tests/common/mod.rs | 16 ++++++++ nucliadb_relations2/tests/common/mod.rs | 16 ++++++++ nucliadb_relations2/tests/test_reader.rs | 47 ++++++++++++++++++++++-- nucliadb_relations2/tests/test_writer.rs | 13 ++++++- 4 files changed, 87 insertions(+), 5 deletions(-) diff --git a/nidx/nidx_relation/tests/common/mod.rs b/nidx/nidx_relation/tests/common/mod.rs index 779a64fe6d..3df20738c7 100644 --- a/nidx/nidx_relation/tests/common/mod.rs +++ b/nidx/nidx_relation/tests/common/mod.rs @@ -49,3 +49,19 @@ pub fn create_relation( }), } } + +pub fn create_relation_with_metadata( + source: String, + source_node_type: NodeType, + source_subtype: String, + to: String, + to_node_type: NodeType, + to_subtype: String, + rel_type: RelationType, + metadata: RelationMetadata, +) -> Relation { + let mut relation = + create_relation(source, source_node_type, source_subtype, to, to_node_type, to_subtype, rel_type); + relation.metadata = Some(metadata); + relation +} diff --git a/nucliadb_relations2/tests/common/mod.rs b/nucliadb_relations2/tests/common/mod.rs index 4edbca3df5..47a4f1d898 100644 --- a/nucliadb_relations2/tests/common/mod.rs +++ b/nucliadb_relations2/tests/common/mod.rs @@ -49,3 +49,19 @@ pub fn create_relation( }), } } + +pub fn 
create_relation_with_metadata( + source: String, + source_node_type: NodeType, + source_subtype: String, + to: String, + to_node_type: NodeType, + to_subtype: String, + rel_type: RelationType, + metadata: RelationMetadata, +) -> Relation { + let mut relation = + create_relation(source, source_node_type, source_subtype, to, to_node_type, to_subtype, rel_type); + relation.metadata = Some(metadata); + relation +} diff --git a/nucliadb_relations2/tests/test_reader.rs b/nucliadb_relations2/tests/test_reader.rs index 98064a416c..28fd90a1d9 100644 --- a/nucliadb_relations2/tests/test_reader.rs +++ b/nucliadb_relations2/tests/test_reader.rs @@ -24,8 +24,8 @@ use nucliadb_core::protos::entities_subgraph_request::DeletedEntities; use nucliadb_core::protos::relation::RelationType; use nucliadb_core::protos::relation_node::NodeType; use nucliadb_core::protos::{ - EntitiesSubgraphRequest, RelationNodeFilter, RelationPrefixSearchRequest, RelationSearchRequest, Resource, - ResourceId, + EntitiesSubgraphRequest, RelationMetadata, RelationNodeFilter, RelationPrefixSearchRequest, RelationSearchRequest, + Resource, ResourceId, }; use nucliadb_core::relations::*; use nucliadb_relations2::reader::RelationsReaderService; @@ -93,7 +93,7 @@ fn create_reader() -> NodeResult { "PEOPLE".to_string(), RelationType::Entity, ), - common::create_relation( + common::create_relation_with_metadata( "Anthony".to_string(), NodeType::Entity, "PEOPLE".to_string(), @@ -101,6 +101,15 @@ fn create_reader() -> NodeResult { NodeType::Entity, "PLACES".to_string(), RelationType::Entity, + RelationMetadata { + paragraph_id: Some("myresource/0/myresource/100-200".to_string()), + source_start: Some(0), + source_end: Some(10), + to_start: Some(11), + to_end: Some(20), + data_augmentation_task_id: Some("mytask".to_string()), + ..Default::default() + }, ), common::create_relation( "Anna".to_string(), @@ -212,6 +221,38 @@ fn test_search() -> NodeResult<()> { Ok(()) } +#[test] +fn test_search_metadata() -> NodeResult<()> { + let reader = create_reader()?; + + let result = reader.search(&RelationSearchRequest { + subgraph: Some(EntitiesSubgraphRequest { + depth: Some(1_i32), + entry_points: vec![common::create_relation_node( + "Anthony".to_string(), + NodeType::Entity, + "PEOPLE".to_string(), + )], + ..Default::default() + }), + ..Default::default() + })?; + + let subgraph = result.subgraph.unwrap(); + assert_eq!(subgraph.relations.len(), 1); + + let relation = &subgraph.relations[0]; + let metadata = relation.metadata.as_ref().unwrap(); + assert_eq!(metadata.paragraph_id, Some("myresource/0/myresource/100-200".to_string())); + assert_eq!(metadata.source_start, Some(0)); + assert_eq!(metadata.source_end, Some(10)); + assert_eq!(metadata.to_start, Some(11)); + assert_eq!(metadata.to_end, Some(20)); + assert_eq!(metadata.data_augmentation_task_id, Some("mytask".to_string())); + + Ok(()) +} + #[test] fn test_prefix_search() -> NodeResult<()> { let reader = create_reader()?; diff --git a/nucliadb_relations2/tests/test_writer.rs b/nucliadb_relations2/tests/test_writer.rs index 57934fa3fb..fa1e2b4d37 100644 --- a/nucliadb_relations2/tests/test_writer.rs +++ b/nucliadb_relations2/tests/test_writer.rs @@ -22,7 +22,7 @@ mod common; use nucliadb_core::prelude::*; use nucliadb_core::protos::relation::RelationType; use nucliadb_core::protos::relation_node::NodeType; -use nucliadb_core::protos::{Resource, ResourceId}; +use nucliadb_core::protos::{RelationMetadata, Resource, ResourceId}; use nucliadb_core::relations::*; use 
nucliadb_relations2::writer::RelationsWriterService; use tempfile::TempDir; @@ -69,7 +69,7 @@ fn test_index_docs() -> NodeResult<()> { shard_id: "shard_id".to_string(), }), relations: vec![ - common::create_relation( + common::create_relation_with_metadata( "cat".to_string(), NodeType::Entity, "ANIMALS".to_string(), @@ -77,6 +77,15 @@ fn test_index_docs() -> NodeResult<()> { NodeType::Entity, "ANIMALS".to_string(), RelationType::Entity, + RelationMetadata { + paragraph_id: Some("myresource/0/myresource/100-200".to_string()), + source_start: Some(0), + source_end: Some(10), + to_start: Some(11), + to_end: Some(20), + data_augmentation_task_id: Some("mytask".to_string()), + ..Default::default() + }, ), common::create_relation( "dolphin".to_string(), From 19faab6c019bf8d593cca5d49110c7ee6aa03c51 Mon Sep 17 00:00:00 2001 From: Carles Onielfa Date: Thu, 9 Jan 2025 09:07:33 +0100 Subject: [PATCH 04/17] Add unit tests to nidx --- nidx/nidx_relation/tests/test_reader.rs | 47 +++++++++++++++++++++++-- nidx/nidx_relation/tests/test_writer.rs | 13 +++++-- 2 files changed, 55 insertions(+), 5 deletions(-) diff --git a/nidx/nidx_relation/tests/test_reader.rs b/nidx/nidx_relation/tests/test_reader.rs index 953915de6f..fe7b4ffba0 100644 --- a/nidx/nidx_relation/tests/test_reader.rs +++ b/nidx/nidx_relation/tests/test_reader.rs @@ -23,8 +23,8 @@ use nidx_protos::entities_subgraph_request::DeletedEntities; use nidx_protos::relation::RelationType; use nidx_protos::relation_node::NodeType; use nidx_protos::{ - EntitiesSubgraphRequest, RelationNodeFilter, RelationPrefixSearchRequest, RelationSearchRequest, Resource, - ResourceId, + EntitiesSubgraphRequest, RelationMetadata, RelationNodeFilter, RelationPrefixSearchRequest, RelationSearchRequest, + Resource, ResourceId, }; use nidx_relation::{RelationIndexer, RelationSearcher}; use nidx_tantivy::{TantivyMeta, TantivySegmentMetadata}; @@ -109,7 +109,7 @@ fn create_reader() -> anyhow::Result { "PEOPLE".to_string(), RelationType::Entity, ), - common::create_relation( + common::create_relation_with_metadata( "Anthony".to_string(), NodeType::Entity, "PEOPLE".to_string(), @@ -117,6 +117,15 @@ fn create_reader() -> anyhow::Result { NodeType::Entity, "PLACES".to_string(), RelationType::Entity, + RelationMetadata { + paragraph_id: Some("myresource/0/myresource/100-200".to_string()), + source_start: Some(0), + source_end: Some(10), + to_start: Some(11), + to_end: Some(20), + data_augmentation_task_id: Some("mytask".to_string()), + ..Default::default() + }, ), common::create_relation( "Anna".to_string(), @@ -203,6 +212,38 @@ fn test_search() -> anyhow::Result<()> { Ok(()) } +#[test] +fn test_search_metadata() -> anyhow::Result<()> { + let reader = create_reader()?; + + let result = reader.search(&RelationSearchRequest { + subgraph: Some(EntitiesSubgraphRequest { + depth: Some(1_i32), + entry_points: vec![common::create_relation_node( + "Anthony".to_string(), + NodeType::Entity, + "PEOPLE".to_string(), + )], + ..Default::default() + }), + ..Default::default() + })?; + + let subgraph = result.subgraph.unwrap(); + assert_eq!(subgraph.relations.len(), 1); + + let relation = &subgraph.relations[0]; + let metadata = relation.metadata.as_ref().unwrap(); + assert_eq!(metadata.paragraph_id, Some("myresource/0/myresource/100-200".to_string())); + assert_eq!(metadata.source_start, Some(0)); + assert_eq!(metadata.source_end, Some(10)); + assert_eq!(metadata.to_start, Some(11)); + assert_eq!(metadata.to_end, Some(20)); + assert_eq!(metadata.data_augmentation_task_id, 
Some("mytask".to_string())); + + Ok(()) +} + #[test] fn test_prefix_search() -> anyhow::Result<()> { let reader = create_reader()?; diff --git a/nidx/nidx_relation/tests/test_writer.rs b/nidx/nidx_relation/tests/test_writer.rs index 1068188c2b..dda78cc50f 100644 --- a/nidx/nidx_relation/tests/test_writer.rs +++ b/nidx/nidx_relation/tests/test_writer.rs @@ -19,7 +19,7 @@ // mod common; -use nidx_protos::{relation::RelationType, relation_node::NodeType, Resource, ResourceId}; +use nidx_protos::{relation::RelationType, relation_node::NodeType, RelationMetadata, Resource, ResourceId}; use nidx_relation::RelationIndexer; use tempfile::TempDir; @@ -42,7 +42,7 @@ fn test_index_docs() -> anyhow::Result<()> { "ANIMALS".to_string(), RelationType::Entity, ), - common::create_relation( + common::create_relation_with_metadata( "01808bbd8e784552967a4fb0d8b6e584".to_string(), NodeType::Resource, "".to_string(), @@ -50,6 +50,15 @@ fn test_index_docs() -> anyhow::Result<()> { NodeType::Entity, "ANIMALS".to_string(), RelationType::Entity, + RelationMetadata { + paragraph_id: Some("myresource/0/myresource/100-200".to_string()), + source_start: Some(0), + source_end: Some(10), + to_start: Some(11), + to_end: Some(20), + data_augmentation_task_id: Some("mytask".to_string()), + ..Default::default() + }, ), ], ..Default::default() From 63f00d52fa3b92556367ddf57aeefa27236123df Mon Sep 17 00:00:00 2001 From: Carles Onielfa Date: Fri, 10 Jan 2025 15:35:50 +0100 Subject: [PATCH 05/17] Graph Strategy begins --- nucliadb/src/nucliadb/common/ids.py | 22 +- .../src/nucliadb/search/search/chat/ask.py | 45 ++- .../src/nucliadb/search/search/chat/prompt.py | 6 +- .../src/nucliadb/search/search/chat/query.py | 108 ++++- .../src/nucliadb/search/search/find_merge.py | 33 ++ .../nucliadb/search/search/graph_strategy.py | 370 ++++++++++++++++++ nucliadb/src/nucliadb/search/search/merge.py | 15 +- nucliadb_models/src/nucliadb_models/search.py | 31 ++ 8 files changed, 598 insertions(+), 32 deletions(-) create mode 100644 nucliadb/src/nucliadb/search/search/graph_strategy.py diff --git a/nucliadb/src/nucliadb/common/ids.py b/nucliadb/src/nucliadb/common/ids.py index 9a954ac204..ff35901a0d 100644 --- a/nucliadb/src/nucliadb/common/ids.py +++ b/nucliadb/src/nucliadb/common/ids.py @@ -111,13 +111,11 @@ def from_string(cls, value: str) -> "FieldId": parts = value.split("/") if len(parts) == 3: rid, _type, key = parts - if _type not in FIELD_TYPE_STR_TO_PB: - raise ValueError(f"Invalid FieldId: {value}") + _type = cls.parse_field_type(_type) return cls(rid=rid, type=_type, key=key) elif len(parts) == 4: rid, _type, key, subfield_id = parts - if _type not in FIELD_TYPE_STR_TO_PB: - raise ValueError(f"Invalid FieldId: {value}") + _type = cls.parse_field_type(_type) return cls( rid=rid, type=_type, @@ -127,6 +125,22 @@ def from_string(cls, value: str) -> "FieldId": else: raise ValueError(f"Invalid FieldId: {value}") + @classmethod + def parse_field_type(cls, _type: str) -> str: + if _type not in FIELD_TYPE_STR_TO_PB: + # Try to parse the enum value + # XXX: This is to support field types that are integer values of FieldType + # Which is how legacy processor relations reported the paragraph_id + try: + type_pb = FieldType.ValueType(_type) + except ValueError: + raise ValueError(f"Invalid FieldId: {_type}") + if type_pb in FIELD_TYPE_PB_TO_STR: + return FIELD_TYPE_PB_TO_STR[type_pb] + else: + raise ValueError(f"Invalid FieldId: {_type}") + return _type + @dataclass class ParagraphId: diff --git 
a/nucliadb/src/nucliadb/search/search/chat/ask.py b/nucliadb/src/nucliadb/search/search/chat/ask.py index 1d6f4ea92b..1647f60d1f 100644 --- a/nucliadb/src/nucliadb/search/search/chat/ask.py +++ b/nucliadb/src/nucliadb/search/search/chat/ask.py @@ -48,6 +48,7 @@ NOT_ENOUGH_CONTEXT_ANSWER, ChatAuditor, get_find_results, + get_graph_results, get_relations_results, rephrase_query, sorted_prompt_context_list, @@ -75,6 +76,7 @@ ErrorAskResponseItem, FindParagraph, FindRequest, + GraphStrategy, JSONAskResponseItem, KnowledgeboxFindResults, MetadataAskResponseItem, @@ -629,6 +631,13 @@ def parse_prequeries(ask_request: AskRequest) -> Optional[PreQueriesStrategy]: return None +def parse_graph_strategy(ask_request: AskRequest) -> Optional[GraphStrategy]: + for rag_strategy in ask_request.rag_strategies: + if rag_strategy.name == RagStrategyName.GRAPH: + return cast(GraphStrategy, rag_strategy) + return None + + async def retrieval_step( kbid: str, main_query: str, @@ -675,17 +684,33 @@ async def retrieval_in_kb( metrics: RAGMetrics, ) -> RetrievalResults: prequeries = parse_prequeries(ask_request) + graph_strategy = parse_graph_strategy(ask_request) with metrics.time("retrieval"): - main_results, prequeries_results, query_parser = await get_find_results( - kbid=kbid, - query=main_query, - item=ask_request, - ndb_client=client_type, - user=user_id, - origin=origin, - metrics=metrics, - prequeries_strategy=prequeries, - ) + prequeries_results = None + if graph_strategy is not None: + main_results, query_parser = await get_graph_results( + kbid=kbid, + query=main_query, + item=ask_request, + ndb_client=client_type, + user=user_id, + origin=origin, + graph_strategy=graph_strategy, + metrics=metrics, + shards=ask_request.shards, + ) + # TODO (oni): Fallback to normal retrieval if no graph results are found + else: + main_results, prequeries_results, query_parser = await get_find_results( + kbid=kbid, + query=main_query, + item=ask_request, + ndb_client=client_type, + user=user_id, + origin=origin, + metrics=metrics, + prequeries_strategy=prequeries, + ) if len(main_results.resources) == 0 and all( len(prequery_result.resources) == 0 for (_, prequery_result) in prequeries_results or [] ): diff --git a/nucliadb/src/nucliadb/search/search/chat/prompt.py b/nucliadb/src/nucliadb/search/search/chat/prompt.py index 6874d30824..f418d6ab0f 100644 --- a/nucliadb/src/nucliadb/search/search/chat/prompt.py +++ b/nucliadb/src/nucliadb/search/search/chat/prompt.py @@ -1012,8 +1012,10 @@ async def _build_context(self, context: CappedPromptContext) -> None: neighbouring_paragraphs = cast(NeighbouringParagraphsStrategy, strategy) elif strategy.name == RagStrategyName.METADATA_EXTENSION: metadata_extension = cast(MetadataExtensionStrategy, strategy) - elif strategy.name != RagStrategyName.PREQUERIES: # pragma: no cover - # Prequeries are not handled here + elif ( + strategy.name != RagStrategyName.PREQUERIES and strategy.name != RagStrategyName.GRAPH + ): # pragma: no cover + # Prequeries and graph are not handled here logger.warning( "Unknown rag strategy", extra={"strategy": strategy.name, "kbid": self.kbid}, diff --git a/nucliadb/src/nucliadb/search/search/chat/query.py b/nucliadb/src/nucliadb/search/search/chat/query.py index e47e95ada9..7ad9e8f7d5 100644 --- a/nucliadb/src/nucliadb/search/search/chat/query.py +++ b/nucliadb/src/nucliadb/search/search/chat/query.py @@ -25,7 +25,8 @@ from nucliadb.search.requesters.utils import Method, node_query from nucliadb.search.search.chat.exceptions import NoRetrievalResultsError 
from nucliadb.search.search.exceptions import IncompleteFindResultsError -from nucliadb.search.search.find import find +from nucliadb.search.search.find import find, query_parser_from_find_request +from nucliadb.search.search.graph_strategy import build_graph_response, rank_relations from nucliadb.search.search.merge import merge_relations_results from nucliadb.search.search.metrics import RAGMetrics from nucliadb.search.search.query import QueryParser @@ -36,6 +37,7 @@ ChatContextMessage, ChatOptions, FindRequest, + GraphStrategy, KnowledgeboxFindResults, NucliaDBClientType, PreQueriesStrategy, @@ -75,6 +77,81 @@ async def rephrase_query( return await predict.rephrase_query(kbid, req) +async def get_graph_results( + *, + kbid: str, + query: str, + item: AskRequest, + ndb_client: NucliaDBClientType, + user: str, + origin: str, + graph_strategy: GraphStrategy, + generative_model: Optional[str] = None, + metrics: RAGMetrics = RAGMetrics(), + shards: Optional[list[str]] = None, +) -> tuple[KnowledgeboxFindResults, QueryParser]: + # TODO: Multi hop + + # 1. Get relations from entities in query + # TODO: Send flag to predict entities to use DA entities once available + relations = await get_relations_results( + kbid=kbid, + text_answer=query, + timeout=5.0, + target_shard_replicas=shards, + only_with_metadata=True, + # use_da_entities=True, + ) + """ + Relations(entities={'Ministry of Environment': EntitySubgraph(related_to=[DirectionalRelation(entity='Ordesa National Park', entity_type=, relation=, relation_label='subsidiary', direction=), DirectionalRelation(entity='National_parks_of_Spain', entity_type=, relation=, relation_label='subsidiary', direction=), DirectionalRelation(entity='cfcde433fc2a4e388360d3d32b03729f', entity_type=, relation=, relation_label='', direction=)])}) + """ + import pdb + + pdb.set_trace() + + # + # 2. Rank the relations and get the top_k + # TODO: Evaluate using suggest + pruned_relations = await rank_relations( + relations, query, kbid, user, top_k=graph_strategy.top_k, generative_model=generative_model + ) + + import pdb + + pdb.set_trace() + + # 3. 
Get the text for the top_k relations + # XXX: We could use the location of head entity and tail entity to get the text instead + # that would result in way less text to retrieve + paragraph_ids = { + r.metadata.paragraph_id + for r in pruned_relations.entities.values() + for r in r.related_to + if r.metadata and r.metadata.paragraph_id + } + find_request = find_request_from_ask_request(item, query) + query_parser, rank_fusion, reranker = await query_parser_from_find_request( + kbid, find_request, generative_model=generative_model + ) + find_results = await build_graph_response( + paragraph_ids, + kbid=kbid, + query=query, + final_relations=pruned_relations, + top_k=graph_strategy.top_k, + reranker=reranker, + show=find_request.show, + extracted=find_request.extracted, + field_type_filter=find_request.field_type_filter, + ) + import pdb + + pdb.set_trace() + # TODO: Report using RAGMetrics + + return find_results, query_parser + + async def get_find_results( *, kbid: str, @@ -144,15 +221,7 @@ async def get_find_results( return main_results, prequeries_results, query_parser -async def run_main_query( - kbid: str, - query: str, - item: AskRequest, - ndb_client: NucliaDBClientType, - user: str, - origin: str, - metrics: RAGMetrics = RAGMetrics(), -) -> tuple[KnowledgeboxFindResults, QueryParser]: +def find_request_from_ask_request(item: AskRequest, query: str) -> FindRequest: find_request = FindRequest() find_request.resource_filters = item.resource_filters find_request.features = [] @@ -188,7 +257,19 @@ async def run_main_query( find_request.show_hidden = item.show_hidden # this executes the model validators, that can tweak some fields - FindRequest.model_validate(find_request) + return FindRequest.model_validate(find_request) + + +async def run_main_query( + kbid: str, + query: str, + item: AskRequest, + ndb_client: NucliaDBClientType, + user: str, + origin: str, + metrics: RAGMetrics = RAGMetrics(), +) -> tuple[KnowledgeboxFindResults, QueryParser]: + find_request = find_request_from_ask_request(item, query) find_results, incomplete, query_parser = await find( kbid, @@ -210,6 +291,7 @@ async def get_relations_results( text_answer: str, target_shard_replicas: Optional[list[str]], timeout: Optional[float] = None, + only_with_metadata: bool = False, ) -> Relations: try: predict = get_predict() @@ -233,7 +315,9 @@ async def get_relations_results( retry_on_primary=False, ) relations_results: list[RelationSearchResponse] = [result.relation for result in results] - return await merge_relations_results(relations_results, request.relation_subgraph) + return await merge_relations_results( + relations_results, request.relation_subgraph, only_with_metadata + ) except Exception as exc: capture_exception(exc) logger.exception("Error getting relations results") diff --git a/nucliadb/src/nucliadb/search/search/find_merge.py b/nucliadb/src/nucliadb/search/search/find_merge.py index 5a163545e9..8e709ee50f 100644 --- a/nucliadb/src/nucliadb/search/search/find_merge.py +++ b/nucliadb/src/nucliadb/search/search/find_merge.py @@ -226,6 +226,39 @@ def merge_shards_relation_responses( return merged +def paragraph_id_to_text_block_match(paragraph_id: str) -> TextBlockMatch: + """ + Given a paragraph_id, return a TextBlockMatch with the bare minimum fields + This is required by the Graph Strategy to get text blocks from the relevant paragraphs + """ + parsed_paragraph_id = ParagraphId.from_string(paragraph_id) + return TextBlockMatch( + paragraph_id=parsed_paragraph_id, + score=0, + score_type=SCORE_TYPE.BM25, + 
order=0, # NOTE: this will be filled later + text="", # NOTE: this will be filled later too + position=TextPosition( + page_number=0, + index=0, + start=parsed_paragraph_id.paragraph_start, + end=parsed_paragraph_id.paragraph_end, + start_seconds=[], + end_seconds=[], + ), + field_labels=[], + paragraph_labels=[], + fuzzy_search=False, + is_a_table=False, + representation_file="", + page_with_visual=False, + ) + + +def paragraph_id_to_text_block_matches(paragraph_ids: Iterable[str]) -> list[TextBlockMatch]: + return [paragraph_id_to_text_block_match(item) for item in paragraph_ids] + + def keyword_result_to_text_block_match(item: ParagraphResult) -> TextBlockMatch: fuzzy_result = len(item.matches) > 0 return TextBlockMatch( diff --git a/nucliadb/src/nucliadb/search/search/graph_strategy.py b/nucliadb/src/nucliadb/search/search/graph_strategy.py new file mode 100644 index 0000000000..d9b2001630 --- /dev/null +++ b/nucliadb/src/nucliadb/search/search/graph_strategy.py @@ -0,0 +1,370 @@ +import heapq +import json +from collections import defaultdict +from typing import Iterable, Optional + +from nuclia_models.predict.generative_responses import ( + JSONGenerativeResponse, + MetaGenerativeResponse, + StatusGenerativeResponse, +) + +from nucliadb.search.search.find_merge import ( + compose_find_resources, + hydrate_and_rerank, + paragraph_id_to_text_block_matches, +) +from nucliadb.search.search.hydrator import ResourceHydrationOptions, TextBlockHydrationOptions +from nucliadb.search.search.rerankers import Reranker, RerankingOptions +from nucliadb.search.utilities import get_predict +from nucliadb_models.common import FieldTypeName +from nucliadb_models.resource import ExtractedDataTypeName +from nucliadb_models.search import ( + ChatModel, + DirectionalRelation, + EntitySubgraph, + KnowledgeboxFindResults, + RelationDirection, + Relations, + ResourceProperties, + UserPrompt, +) + +SCHEMA = { + "title": "score_triplets", + "description": "Return a list of triplets and their relevance scores (0-10) for the supplied question.", + "type": "object", + "properties": { + "triplets": { + "type": "array", + "description": "A list of triplets with their relevance scores.", + "items": { + "type": "object", + "properties": { + "head_entity": {"type": "string", "description": "The first entity in the triplet."}, + "relationship": { + "type": "string", + "description": "The relationship between the two entities.", + }, + "tail_entity": { + "type": "string", + "description": "The second entity in the triplet.", + }, + "score": { + "type": "integer", + "description": "A relevance score in the range 0 to 10.", + "minimum": 0, + "maximum": 10, + }, + }, + "required": ["head_entity", "relationship", "tail_entity", "score"], + }, + } + }, + "required": ["triplets"], +} + +PROMPT = """\ +You are an advanced language model assisting in scoring relationships (edges) between two entities in a knowledge graph, given a user’s question. + +For each provided **(head_entity, relationship, tail_entity)**, you must: +1. Assign a **relevance score** between **0** and **10**. +2. **0** means “this relationship can’t be relevant at all to the question.” +3. **10** means “this relationship is extremely relevant to the question.” +4. You may use **any integer** between 0 and 10 (e.g., 3, 7, etc.) based on how relevant you deem the relationship to be. +5. **Language Agnosticism**: The question and the relationships may be in different languages. The relevance scoring should still work and be agnostic of the language. +6. 
Relationships that may not answer the question directly but expand knowledge in a relevant way should also be scored positively.

Once you have decided the best score for each triplet, return these results **using a function call** in JSON format with the following rules:

- The function name should be `score_triplets`.
- The first argument should be the list of triplets.
- Each triplet should have the following keys:
  - `head_entity`: The first entity in the triplet.
  - `relationship`: The relationship between the two entities.
  - `tail_entity`: The second entity in the triplet.
  - `score`: The relevance score in the range 0 to 10.

You **must** comply with the provided JSON Schema to ensure a well-structured response and maintain the order of the triplets.


## Examples:

### Example 1:

**Input**

{
    "question": "Who is the mayor of the capital city of Australia?",
    "triplets": [
        {
            "head_entity": "Australia",
            "relationship": "has prime minister",
            "tail_entity": "Scott Morrison"
        },
        {
            "head_entity": "Canberra",
            "relationship": "is capital of",
            "tail_entity": "Australia"
        },
        {
            "head_entity": "Scott Knowles",
            "relationship": "holds position",
            "tail_entity": "Mayor"
        },
        {
            "head_entity": "Barbera Smith",
            "relationship": "tiene cargo",
            "tail_entity": "Alcalde"
        },
        {
            "head_entity": "Austria",
            "relationship": "has capital",
            "tail_entity": "Vienna"
        }
    ]
}

**Output**

{
    "triplets": [
        {
            "head_entity": "Australia",
            "relationship": "has prime minister",
            "tail_entity": "Scott Morrison",
            "score": 4
        },
        {
            "head_entity": "Canberra",
            "relationship": "is capital of",
            "tail_entity": "Australia",
            "score": 8
        },
        {
            "head_entity": "Scott Knowles",
            "relationship": "holds position",
            "tail_entity": "Mayor",
            "score": 8
        },
        {
            "head_entity": "Barbera Smith",
            "relationship": "tiene cargo",
            "tail_entity": "Alcalde",
            "score": 8
        },
        {
            "head_entity": "Austria",
            "relationship": "has capital",
            "tail_entity": "Vienna",
            "score": 0
        }
    ]
}



### Example 2:

**Input**

{
    "question": "How many products does John Adams Roofing Inc. 
offer?", + "triplets": [ + { + "head_entity": "John Adams Roofing Inc.", + "relationship": "has product", + "tail_entity": "Titanium Grade 3 Roofing Nails" + }, + { + "head_entity": "John Adams Roofing Inc.", + "relationship": "is located in", + "tail_entity": "New York" + }, + { + "head_entity": "John Adams Roofing Inc.", + "relationship": "was founded by", + "tail_entity": "John Adams" + }, + { + "head_entity": "John Adams Roofing Inc.", + "relationship": "tiene stock", + "tail_entity": "Baldosas solares" + }, + { + "head_entity": "John Adams Roofing Inc.", + "relationship": "has product", + "tail_entity": "Mercerized Cotton Thread" + } + ] +} + +**Output** + +{ + "triplets": [ + { + "head_entity": "John Adams Roofing Inc.", + "relationship": "has product", + "tail_entity": "Titanium Grade 3 Roofing Nails", + "score": 10 + }, + { + "head_entity": "John Adams Roofing Inc.", + "relationship": "is located in", + "tail_entity": "New York", + "score": 6 + }, + { + "head_entity": "John Adams Roofing Inc.", + "relationship": "was founded by", + "tail_entity": "John Adams", + "score": 5 + }, + { + "head_entity": "John Adams Roofing Inc.", + "relationship": "tiene stock", + "tail_entity": "Baldosas solares", + "score": 10 + }, + { + "head_entity": "John Adams Roofing Inc.", + "relationship": "has product", + "tail_entity": "Mercerized Cotton Thread", + "score": 10 + } + ] +} + +Now, let's get started! Here are the triplets you need to score: + +**Input** + +""" + + +async def rank_relations( + relations: Relations, + query: str, + kbid: str, + user: str, + top_k: int, + generative_model: Optional[str] = None, +) -> Relations: + # Store the index for keeping track after scoring + flat_rels: list[tuple[str, int, DirectionalRelation]] = [ + (ent, idx, rel) + for (ent, rels) in relations.entities.items() + for (idx, rel) in enumerate(rels.related_to) + ] + triplets: list[dict[str, str]] = [ + { + "head_entity": ent, + "relationship": rel.relation_label, + "tail_entity": rel.entity, + } + if rel.direction == RelationDirection.OUT + else { + "head_entity": rel.entity, + "relationship": rel.relation_label, + "tail_entity": ent, + } + for (ent, _, rel) in flat_rels + ] + data = { + "question": query, + "triplets": triplets, + } + prompt = PROMPT + json.dumps(data, indent=4) + + predict = get_predict() + item = ChatModel( + question=prompt, + user_id=user, + json_schema=SCHEMA, + format_prompt=False, # We supply our own prompt + query_context_order={}, + query_context={}, + user_prompt=UserPrompt(prompt=prompt), + max_tokens=4096, + generative_model=generative_model, + ) + # TODO: Enclose this in a try-except block + ident, model, answer_stream = await predict.chat_query_ndjson(kbid, item) + response_json = None + status = None + meta = None + + async for generative_chunk in answer_stream: + item = generative_chunk.chunk + if isinstance(item, JSONGenerativeResponse): + response_json = item + elif isinstance(item, StatusGenerativeResponse): + status = item + elif isinstance(item, MetaGenerativeResponse): + meta = item + else: + # TODO: Improve for logging + raise ValueError(f"Unknown generative chunk type: {item}") + + # TODO: Report tokens using meta? 
+ + if response_json is None or status is None or status.code != "0": + raise ValueError("No JSON response found") + + scored_triplets = response_json.object["triplets"] + + if len(scored_triplets) != len(flat_rels): + raise ValueError("Mismatch between input and output triplets") + scores = ((idx, scored_triplet["score"]) for (idx, scored_triplet) in enumerate(scored_triplets)) + top_k_scores = heapq.nlargest(top_k, scores, key=lambda x: x[1]) + top_k_rels = defaultdict(lambda: EntitySubgraph(related_to=[])) + for idx_flat, _ in top_k_scores: + (ent, idx, _) = flat_rels[idx_flat] + rel = relations.entities[ent].related_to[idx] + top_k_rels[ent].related_to.append(rel) + + return Relations(entities=top_k_rels) + + +async def build_graph_response( + paragraph_ids: Iterable[str], + *, + kbid: str, + query: str, + final_relations: Relations, + top_k: int, + reranker: Reranker, + show: list[ResourceProperties] = [], + extracted: list[ExtractedDataTypeName] = [], + field_type_filter: list[FieldTypeName] = [], +) -> KnowledgeboxFindResults: + # manually generate paragraph results + + text_blocks = paragraph_id_to_text_block_matches(paragraph_ids) + + # hydrate and rerank + resource_hydration_options = ResourceHydrationOptions( + show=show, extracted=extracted, field_type_filter=field_type_filter + ) + text_block_hydration_options = TextBlockHydrationOptions() + reranking_options = RerankingOptions(kbid=kbid, query=query) + text_blocks, resources, best_matches = await hydrate_and_rerank( + text_blocks, + kbid, + resource_hydration_options=resource_hydration_options, + text_block_hydration_options=text_block_hydration_options, + reranker=reranker, + reranking_options=reranking_options, + top_k=top_k, + ) + + find_resources = compose_find_resources(text_blocks, resources) + + return KnowledgeboxFindResults( + query=query, + resources=find_resources, + best_matches=best_matches, + relations=final_relations, + total=len(text_blocks), + ) diff --git a/nucliadb/src/nucliadb/search/search/merge.py b/nucliadb/src/nucliadb/search/search/merge.py index 4fcdee3b77..abace7d5f9 100644 --- a/nucliadb/src/nucliadb/search/search/merge.py +++ b/nucliadb/src/nucliadb/search/search/merge.py @@ -33,7 +33,7 @@ ) from nucliadb_models.common import FieldTypeName from nucliadb_models.labels import translate_system_to_alias_label -from nucliadb_models.metadata import RelationTypePbMap +from nucliadb_models.metadata import RelationMetadata, RelationTypePbMap from nucliadb_models.resource import ExtractedDataTypeName from nucliadb_models.search import ( DirectionalRelation, @@ -432,14 +432,18 @@ async def merge_paragraph_results( async def merge_relations_results( relations_responses: list[RelationSearchResponse], query: EntitiesSubgraphRequest, + only_with_metadata: bool = False, ) -> Relations: loop = asyncio.get_event_loop() - return await loop.run_in_executor(None, _merge_relations_results, relations_responses, query) + return await loop.run_in_executor( + None, _merge_relations_results, relations_responses, query, only_with_metadata + ) def _merge_relations_results( relations_responses: list[RelationSearchResponse], query: EntitiesSubgraphRequest, + only_with_metadata: bool, ) -> Relations: relations = Relations(entities={}) @@ -452,8 +456,9 @@ def _merge_relations_results( destination = relation.to relation_type = RelationTypePbMap[relation.relation] relation_label = relation.relation_label + metadata = relation.metadata if relation.HasField("metadata") else None - if origin.value in relations.entities: + if (not 
only_with_metadata or metadata) and origin.value in relations.entities:
             relations.entities[origin.value].related_to.append(
                 DirectionalRelation(
                     entity=destination.value,
@@ -461,9 +466,10 @@ def _merge_relations_results(
                     relation=relation_type,
                     relation_label=relation_label,
                     direction=RelationDirection.OUT,
+                    metadata=RelationMetadata.from_message(metadata) if metadata else None,
                 )
             )
-        elif destination.value in relations.entities:
+        elif (not only_with_metadata or metadata) and destination.value in relations.entities:
             relations.entities[destination.value].related_to.append(
                 DirectionalRelation(
                     entity=origin.value,
@@ -471,6 +477,7 @@
                     relation=relation_type,
                     relation_label=relation_label,
                     direction=RelationDirection.IN,
+                    metadata=RelationMetadata.from_message(metadata) if metadata else None,
                 )
             )

diff --git a/nucliadb_models/src/nucliadb_models/search.py b/nucliadb_models/src/nucliadb_models/search.py
index 005047f25e..e3b4061ccf 100644
--- a/nucliadb_models/src/nucliadb_models/search.py
+++ b/nucliadb_models/src/nucliadb_models/search.py
@@ -25,6 +25,7 @@
 from pydantic.json_schema import SkipJsonSchema
 from typing_extensions import Annotated, Self, deprecated

+from nucliadb_models import RelationMetadata
 from nucliadb_models.common import FieldTypeName, ParamDefault

 # Bw/c import to avoid breaking users
@@ -256,6 +257,7 @@ class DirectionalRelation(BaseModel):
     relation: RelationType
     relation_label: str
     direction: RelationDirection
+    metadata: Optional[RelationMetadata] = None


 class EntitySubgraph(BaseModel):
@@ -981,6 +983,11 @@ class ChatModel(BaseModel):
     )
     top_k: Optional[int] = Field(default=None, description="Number of best elements to get from")

+    format_prompt: bool = Field(
+        default=True,
+        description="If set to false, the prompt will be used as is, without any formatting for query or context",
+    )
+

 class RephraseModel(BaseModel):
     question: str
@@ -1002,6 +1009,7 @@ class RagStrategyName:
     METADATA_EXTENSION = "metadata_extension"
     PREQUERIES = "prequeries"
     CONVERSATION = "conversation"
+    GRAPH = "graph"


 class ImageRagStrategyName:
@@ -1230,6 +1238,27 @@ class PreQueriesStrategy(RagStrategy):
 PreQueryResult = tuple[PreQuery, "KnowledgeboxFindResults"]


+class GraphStrategy(RagStrategy):
+    """
+    This strategy retrieves context pieces by exploring the Knowledge Graph, starting from the entities present in the query.
+    It works best if the Knowledge Box has a user-defined Graph Extraction agent enabled.
+    """
+
+    name: Literal["graph"] = "graph"
+    n_hops: int = Field(
+        default=1,
+        title="Number of hops",
+        description="Number of hops to take when exploring the graph for relevant context. Bigger values will take more time to compute.",
+        ge=1,
+    )
+    top_k: int = Field(
+        default=20,
+        title="Top k",
+        description="Number of relationships to keep after each hop. This number correlates to more paragraphs being sent as context.",
+        ge=1,
+    )
+
+
 class TableImageStrategy(ImageRagStrategy):
     name: Literal["tables"] = "tables"

@@ -1256,6 +1285,7 @@ class ParagraphImageStrategy(ImageRagStrategy):
     MetadataExtensionStrategy,
     ConversationalStrategy,
     PreQueriesStrategy,
+    GraphStrategy,
 ],
 Field(discriminator="name"),
 ]
@@ -1405,6 +1435,7 @@ class AskRequest(AuditMetadataBase):
 - `neighbouring_paragraphs` will add the sorrounding paragraphs to the context for each matching paragraph.
 - `metadata_extension` will add the metadata of the matching paragraphs or its resources to the context.
- `prequeries` allows to run multiple retrieval queries before the main query and add the results to the context. The results of specific queries can be boosted by the specifying weights.
+- `graph` will retrieve context pieces by exploring the Knowledge Graph, starting from the entities present in the query.

 If empty, the default strategy is used, which simply adds the text of the matching paragraphs to the context.
 """

From 17261aa9d200ca82e026ebf5a48edcff283a5bf5 Mon Sep 17 00:00:00 2001
From: Carles Onielfa
Date: Mon, 13 Jan 2025 09:59:29 +0100
Subject: [PATCH 06/17] Use suggest for fuzzy matching entities instead of
 predict

---
 .../src/nucliadb/search/search/chat/query.py  | 117 ++++++++++++------
 .../nucliadb/search/search/graph_strategy.py  |  70 +++++++++++
 nucliadb/src/nucliadb/search/search/merge.py  |   6 +-
 3 files changed, 151 insertions(+), 42 deletions(-)

diff --git a/nucliadb/src/nucliadb/search/search/chat/query.py b/nucliadb/src/nucliadb/search/search/chat/query.py
index 7667e3bb65..d43fef516d 100644
--- a/nucliadb/src/nucliadb/search/search/chat/query.py
+++ b/nucliadb/src/nucliadb/search/search/chat/query.py
@@ -18,7 +18,7 @@
 # along with this program. If not, see <http://www.gnu.org/licenses/>.
 #
 import asyncio
-from typing import Optional
+from typing import Iterable, Optional

 from nucliadb.common.models_utils import to_proto
 from nucliadb.search import logger
@@ -27,7 +27,11 @@
 from nucliadb.search.search.chat.exceptions import NoRetrievalResultsError
 from nucliadb.search.search.exceptions import IncompleteFindResultsError
 from nucliadb.search.search.find import find, query_parser_from_find_request
-from nucliadb.search.search.graph_strategy import build_graph_response, rank_relations
+from nucliadb.search.search.graph_strategy import (
+    build_graph_response,
+    fuzzy_search_entities,
+    rank_relations,
+)
 from nucliadb.search.search.merge import merge_relations_results
 from nucliadb.search.search.metrics import RAGMetrics
 from nucliadb.search.search.query import QueryParser
@@ -53,6 +57,7 @@
 )
 from nucliadb_protos import audit_pb2
 from nucliadb_protos.nodereader_pb2 import RelationSearchResponse, SearchRequest, SearchResponse
+from nucliadb_protos.utils_pb2 import RelationNode
 from nucliadb_telemetry.errors import capture_exception
 from nucliadb_utils.utilities import get_audit
@@ -92,27 +97,47 @@ async def get_graph_results(
     metrics: RAGMetrics = RAGMetrics(),
     shards: Optional[list[str]] = None,
 ) -> tuple[KnowledgeboxFindResults, QueryParser]:
     # TODO: Multi hop
-
+    # TODO: Timing using RAGMetrics
+    # TODO: Exception handling
     # 1. 
Get relations from entities in query # TODO: Send flag to predict entities to use DA entities once available - relations = await get_relations_results( + # TODO: Set this as an optional mode + # relations = await get_relations_results( + # kbid=kbid, + # text_answer=query, + # timeout=5.0, + # target_shard_replicas=shards, + # only_with_metadata=True, + # # use_da_entities=True, + # ) + suggest_result = await fuzzy_search_entities( kbid=kbid, - text_answer=query, - timeout=5.0, - target_shard_replicas=shards, - only_with_metadata=True, - # use_da_entities=True, + query=query, + show=item.show, # This show might need to be manually set + field_type_filter=item.field_type_filter, + range_creation_start=item.range_creation_start, + range_creation_end=item.range_creation_end, + range_modification_start=item.range_modification_start, + range_modification_end=item.range_modification_end, ) - """ - Relations(entities={'Ministry of Environment': EntitySubgraph(related_to=[DirectionalRelation(entity='Ordesa National Park', entity_type=, relation=, relation_label='subsidiary', direction=), DirectionalRelation(entity='National_parks_of_Spain', entity_type=, relation=, relation_label='subsidiary', direction=), DirectionalRelation(entity='cfcde433fc2a4e388360d3d32b03729f', entity_type=, relation=, relation_label='', direction=)])}) - """ - import pdb - - pdb.set_trace() + # Convert them to RelationNode in order to perform a relations query + if suggest_result.entities is not None: + relation_nodes = ( + RelationNode(ntype=RelationNode.NodeType.ENTITY, value=result.value, subtype=result.family) + for result in suggest_result.entities.entities + ) + relations = await get_relations_results_from_entities( + kbid=kbid, + entities=relation_nodes, + target_shard_replicas=suggest_result.shards, + timeout=5.0, + only_with_metadata=True, + ) + else: + relations = Relations(entities={}) - # # 2. Rank the relations and get the top_k - # TODO: Evaluate using suggest + explored_entities = {} pruned_relations = await rank_relations( relations, query, kbid, user, top_k=graph_strategy.top_k, generative_model=generative_model ) @@ -122,8 +147,6 @@ async def get_graph_results( pdb.set_trace() # 3. 
Get the text for the top_k relations - # XXX: We could use the location of head entity and tail entity to get the text instead - # that would result in way less text to retrieve paragraph_ids = { r.metadata.paragraph_id for r in pruned_relations.entities.values() @@ -148,7 +171,6 @@ async def get_graph_results( import pdb pdb.set_trace() - # TODO: Report using RAGMetrics return find_results, query_parser @@ -297,27 +319,13 @@ async def get_relations_results( try: predict = get_predict() detected_entities = await predict.detect_entities(kbid, text_answer) - request = SearchRequest() - request.relation_subgraph.entry_points.extend(detected_entities) - request.relation_subgraph.depth = 1 - - results: list[SearchResponse] - ( - results, - _, - _, - ) = await node_query( - kbid, - Method.SEARCH, - request, + + return await get_relations_results_from_entities( + kbid=kbid, + entities=detected_entities, target_shard_replicas=target_shard_replicas, timeout=timeout, - use_read_replica_nodes=True, - retry_on_primary=False, - ) - relations_results: list[RelationSearchResponse] = [result.relation for result in results] - return await merge_relations_results( - relations_results, request.relation_subgraph, only_with_metadata + only_with_metadata=only_with_metadata, ) except Exception as exc: capture_exception(exc) @@ -325,6 +333,37 @@ async def get_relations_results( return Relations(entities={}) +async def get_relations_results_from_entities( + *, + kbid: str, + entities: Iterable[RelationNode], + target_shard_replicas: Optional[list[str]], + timeout: Optional[float] = None, + only_with_metadata: bool = False, +) -> Relations: + request = SearchRequest() + request.relation_subgraph.entry_points.extend(entities) + request.relation_subgraph.depth = 1 + results: list[SearchResponse] + ( + results, + _, + _, + ) = await node_query( + kbid, + Method.SEARCH, + request, + target_shard_replicas=target_shard_replicas, + timeout=timeout, + use_read_replica_nodes=True, + retry_on_primary=False, + ) + relations_results: list[RelationSearchResponse] = [result.relation for result in results] + return await merge_relations_results( + relations_results, request.relation_subgraph, only_with_metadata + ) + + def maybe_audit_chat( *, kbid: str, diff --git a/nucliadb/src/nucliadb/search/search/graph_strategy.py b/nucliadb/src/nucliadb/search/search/graph_strategy.py index d9b2001630..74de60036c 100644 --- a/nucliadb/src/nucliadb/search/search/graph_strategy.py +++ b/nucliadb/src/nucliadb/search/search/graph_strategy.py @@ -1,6 +1,27 @@ +# Copyright (C) 2021 Bosutech XXI S.L. +# +# nucliadb is offered under the AGPL v3.0 and as commercial software. +# For commercial licensing, contact us at info@nuclia.com. +# +# AGPL: +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as +# published by the Free Software Foundation, either version 3 of the +# License, or (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Affero General Public License for more details. +# +# You should have received a copy of the GNU Affero General Public License +# along with this program. If not, see . 
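+#
+# This module implements the `graph` RAG strategy: fuzzy matching of query
+# words against indexed entities (via suggest), LLM scoring of the candidate
+# triplets, and hydration of the paragraphs referenced by the surviving
+# relations.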
+ +import asyncio import heapq import json from collections import defaultdict +from datetime import datetime from typing import Iterable, Optional from nuclia_models.predict.generative_responses import ( @@ -9,12 +30,14 @@ StatusGenerativeResponse, ) +from nucliadb.search.requesters.utils import Method, node_query from nucliadb.search.search.find_merge import ( compose_find_resources, hydrate_and_rerank, paragraph_id_to_text_block_matches, ) from nucliadb.search.search.hydrator import ResourceHydrationOptions, TextBlockHydrationOptions +from nucliadb.search.search.merge import merge_suggest_results from nucliadb.search.search.rerankers import Reranker, RerankingOptions from nucliadb.search.utilities import get_predict from nucliadb_models.common import FieldTypeName @@ -24,11 +47,13 @@ DirectionalRelation, EntitySubgraph, KnowledgeboxFindResults, + KnowledgeboxSuggestResults, RelationDirection, Relations, ResourceProperties, UserPrompt, ) +from nucliadb_protos import nodereader_pb2 SCHEMA = { "title": "score_triplets", @@ -243,6 +268,51 @@ """ +async def fuzzy_search_entities( + kbid: str, + query: str, + show: list[ResourceProperties], + field_type_filter: list[FieldTypeName], + range_creation_start: Optional[datetime] = None, + range_creation_end: Optional[datetime] = None, + range_modification_start: Optional[datetime] = None, + range_modification_end: Optional[datetime] = None, +) -> KnowledgeboxSuggestResults: + """Fuzzy find entities in KB given a query using the same methodology as /suggest, but split by words.""" + + base_request = nodereader_pb2.SuggestRequest( + body="", features=[nodereader_pb2.SuggestFeatures.ENTITIES] + ) + if range_creation_start is not None: + base_request.timestamps.from_created.FromDatetime(range_creation_start) + if range_creation_end is not None: + base_request.timestamps.to_created.FromDatetime(range_creation_end) + if range_modification_start is not None: + base_request.timestamps.from_modified.FromDatetime(range_modification_start) + if range_modification_end is not None: + base_request.timestamps.to_modified.FromDatetime(range_modification_end) + + tasks = [] + # XXX: Splitting by words is not ideal, in the future, modify suggest to better handle this + for word in query.split(): + if len(word) <= 3: + continue + request = nodereader_pb2.SuggestRequest() + request.CopyFrom(base_request) + request.body = word + tasks.append(node_query(kbid, Method.SUGGEST, request)) + + # Gather + # TODO: What do I do with `incomplete_results`? 
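+    # Each node_query task resolves to a 3-tuple whose first element is the
+    # list of shard responses; only those responses (r[0]) are merged below.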
+ results_raw = await asyncio.gather(*tasks) + return await merge_suggest_results( + [item for r in results_raw for item in r[0]], + kbid=kbid, + show=show, + field_type_filter=field_type_filter, + ) + + async def rank_relations( relations: Relations, query: str, diff --git a/nucliadb/src/nucliadb/search/search/merge.py b/nucliadb/src/nucliadb/search/search/merge.py index 86b58b0c18..c2cd44363c 100644 --- a/nucliadb/src/nucliadb/search/search/merge.py +++ b/nucliadb/src/nucliadb/search/search/merge.py @@ -23,6 +23,7 @@ from typing import Any, Optional, Set, Union from nucliadb.common.ids import FieldId, ParagraphId +from nucliadb.common.models_utils import from_proto from nucliadb.common.models_utils.from_proto import RelationTypePbMap from nucliadb.search.search import cache from nucliadb.search.search.cut import cut_page @@ -34,7 +35,6 @@ ) from nucliadb_models.common import FieldTypeName from nucliadb_models.labels import translate_system_to_alias_label -from nucliadb_models.metadata import RelationMetadata from nucliadb_models.resource import ExtractedDataTypeName from nucliadb_models.search import ( DirectionalRelation, @@ -477,7 +477,7 @@ def _merge_relations_results( relation=relation_type, relation_label=relation_label, direction=RelationDirection.OUT, - metadata=RelationMetadata.from_message(metadata) if metadata else None, + metadata=from_proto.relation_metadata(metadata) if metadata else None, ) ) elif (not only_with_metadata or metadata) and destination.value in relations.entities: @@ -488,7 +488,7 @@ def _merge_relations_results( relation=relation_type, relation_label=relation_label, direction=RelationDirection.IN, - metadata=RelationMetadata.from_message(metadata) if metadata else None, + metadata=from_proto.relation_metadata(metadata) if metadata else None, ) ) From ddb0d654d609631ad4bd06f9cbe945c84c0a8430 Mon Sep 17 00:00:00 2001 From: Carles Onielfa Date: Mon, 13 Jan 2025 14:59:01 +0100 Subject: [PATCH 07/17] Multi hop --- nucliadb/src/nucliadb/common/ids.py | 2 +- .../src/nucliadb/search/search/chat/query.py | 65 ++++++-- .../nucliadb/search/search/graph_strategy.py | 75 +++++++-- nucliadb/src/nucliadb/search/search/merge.py | 3 +- .../tests/search/unit/test_graph_strategy.py | 144 ++++++++++++++++++ nucliadb_models/src/nucliadb_models/search.py | 10 +- 6 files changed, 269 insertions(+), 30 deletions(-) create mode 100644 nucliadb/tests/search/unit/test_graph_strategy.py diff --git a/nucliadb/src/nucliadb/common/ids.py b/nucliadb/src/nucliadb/common/ids.py index ff35901a0d..404dd978c4 100644 --- a/nucliadb/src/nucliadb/common/ids.py +++ b/nucliadb/src/nucliadb/common/ids.py @@ -132,7 +132,7 @@ def parse_field_type(cls, _type: str) -> str: # XXX: This is to support field types that are integer values of FieldType # Which is how legacy processor relations reported the paragraph_id try: - type_pb = FieldType.ValueType(_type) + type_pb = FieldType.ValueType(int(_type)) except ValueError: raise ValueError(f"Invalid FieldId: {_type}") if type_pb in FIELD_TYPE_PB_TO_STR: diff --git a/nucliadb/src/nucliadb/search/search/chat/query.py b/nucliadb/src/nucliadb/search/search/chat/query.py index d43fef516d..dc7e5cc245 100644 --- a/nucliadb/src/nucliadb/search/search/chat/query.py +++ b/nucliadb/src/nucliadb/search/search/chat/query.py @@ -29,6 +29,7 @@ from nucliadb.search.search.find import find, query_parser_from_find_request from nucliadb.search.search.graph_strategy import ( build_graph_response, + filter_subgraph, fuzzy_search_entities, rank_relations, ) @@ -96,7 +97,6 @@ async 
def get_graph_results( metrics: RAGMetrics = RAGMetrics(), shards: Optional[list[str]] = None, ) -> tuple[KnowledgeboxFindResults, QueryParser]: - # TODO: Multi hop # TODO: Timing using RAGMetrics # TODO: Exception handling # 1. Get relations from entities in query @@ -119,6 +119,7 @@ async def get_graph_results( range_creation_end=item.range_creation_end, range_modification_start=item.range_modification_start, range_modification_end=item.range_modification_end, + target_shard_replicas=shards, ) # Convert them to RelationNode in order to perform a relations query if suggest_result.entities is not None: @@ -135,22 +136,67 @@ async def get_graph_results( ) else: relations = Relations(entities={}) + # TODO: Apply process_subgraph to the relations + + explored_entities = set(relations.entities.keys()) # 2. Rank the relations and get the top_k - explored_entities = {} - pruned_relations = await rank_relations( + # TODO: Add upper bound to the number of entities to explore for safety + relations = await rank_relations( relations, query, kbid, user, top_k=graph_strategy.top_k, generative_model=generative_model ) - import pdb + for hop in range(graph_strategy.hops - 1): + entities_to_explore: list[RelationNode] = [] + # Find neighbors of the pruned relations and remove the ones already explored + for subgraph in relations.entities.values(): + for relation in subgraph.related_to: + if relation.entity not in explored_entities: + entities_to_explore.append( + RelationNode( + ntype=RelationNode.NodeType.ENTITY, + value=relation.entity, + subtype=relation.entity_subtype, + ) + ) + + # Get the relations for the new entities + new_relations = await get_relations_results_from_entities( + kbid=kbid, + entities=entities_to_explore, + target_shard_replicas=shards, + timeout=5.0, + only_with_metadata=True, + ) - pdb.set_trace() + # Removing the relations connected to the entities that were already explored + # XXX: This could be optimized by implementing a filter in the index + # so we don't have to remove them after + new_subgraphs = { + entity: filter_subgraph(subgraph, explored_entities) + for entity, subgraph in new_relations.entities.items() + } + if not new_subgraphs or any(not subgraph.related_to for subgraph in new_subgraphs.values()): + break + + explored_entities.update(new_subgraphs.keys()) + relations.entities.update(new_subgraphs) + + # Rank the new relations + relations = await rank_relations( + relations, + query, + kbid, + user, + top_k=graph_strategy.top_k, + generative_model=generative_model, + ) # 3. 
Get the text for the top_k relations paragraph_ids = { r.metadata.paragraph_id - for r in pruned_relations.entities.values() - for r in r.related_to + for rel in relations.entities.values() + for r in rel.related_to if r.metadata and r.metadata.paragraph_id } find_request = find_request_from_ask_request(item, query) @@ -161,16 +207,13 @@ async def get_graph_results( paragraph_ids, kbid=kbid, query=query, - final_relations=pruned_relations, + final_relations=relations, top_k=graph_strategy.top_k, reranker=reranker, show=find_request.show, extracted=find_request.extracted, field_type_filter=find_request.field_type_filter, ) - import pdb - - pdb.set_trace() return find_results, query_parser diff --git a/nucliadb/src/nucliadb/search/search/graph_strategy.py b/nucliadb/src/nucliadb/search/search/graph_strategy.py index 74de60036c..a31ed77884 100644 --- a/nucliadb/src/nucliadb/search/search/graph_strategy.py +++ b/nucliadb/src/nucliadb/search/search/graph_strategy.py @@ -22,7 +22,7 @@ import json from collections import defaultdict from datetime import datetime -from typing import Iterable, Optional +from typing import Any, Collection, Iterable, Optional, Union from nuclia_models.predict.generative_responses import ( JSONGenerativeResponse, @@ -90,7 +90,7 @@ } PROMPT = """\ -You are an advanced language model assisting in scoring relationships (edges) between two entities in a knowledge graph, given a user’s question. +You are an advanced language model assisting in scoring relationships (edges) between two entities in a knowledge graph, given a user’s question. For each provided **(head_entity, relationship, tail_entity)**, you must: 1. Assign a **relevance score** between **0** and **10**. @@ -98,7 +98,7 @@ 3. **10** means “this relationship is extremely relevant to the question.” 4. You may use **any integer** between 0 and 10 (e.g., 3, 7, etc.) based on how relevant you deem the relationship to be. 5. **Language Agnosticism**: The question and the relationships may be in different languages. The relevance scoring should still work and be agnostic of the language. -6. Relationships that may not answer the question directly but expand knowledge in a relevant way, should also be scored positively. +6. Relationships that may not answer the question directly but expand knowledge in a relevant way, should also be scored positively. Once you have decided the best score for each triplet, return these results **using a function call** in JSON format with the following rules: @@ -277,6 +277,7 @@ async def fuzzy_search_entities( range_creation_end: Optional[datetime] = None, range_modification_start: Optional[datetime] = None, range_modification_end: Optional[datetime] = None, + target_shard_replicas: Optional[list[str]] = None, ) -> KnowledgeboxSuggestResults: """Fuzzy find entities in KB given a query using the same methodology as /suggest, but split by words.""" @@ -300,7 +301,9 @@ async def fuzzy_search_entities( request = nodereader_pb2.SuggestRequest() request.CopyFrom(base_request) request.body = word - tasks.append(node_query(kbid, Method.SUGGEST, request)) + tasks.append( + node_query(kbid, Method.SUGGEST, request, target_shard_replicas=target_shard_replicas) + ) # Gather # TODO: What do I do with `incomplete_results`? 
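The rank_relations hunk below introduces a small but important piece of bookkeeping: identical (head_entity, relationship, tail_entity) triplets arriving from different subgraphs are collapsed before being sent to the model, and each unique score is fanned back out to every original position afterwards. A minimal sketch of that pattern on toy data, independent of the NucliaDB types:

    triplets = [
        ("Leonardo DiCaprio", "starred", "Inception"),
        ("Christopher Nolan", "directed", "Inception"),
        ("Leonardo DiCaprio", "starred", "Inception"),  # duplicate from another subgraph
    ]

    # Remember every original position of each unique triplet.
    key_to_indices: dict[tuple[str, str, str], list[int]] = {}
    unique: list[tuple[str, str, str]] = []
    for i, triplet in enumerate(triplets):
        if triplet not in key_to_indices:
            key_to_indices[triplet] = []
            unique.append(triplet)
        key_to_indices[triplet].append(i)

    # Pretend the LLM scored only the deduplicated triplets.
    llm_scores = {unique[0]: 9, unique[1]: 7}

    # Fan each unique score back out to every original slot.
    scores = [0] * len(triplets)
    for triplet in unique:
        for original_index in key_to_indices[triplet]:
            scores[original_index] = llm_scores[triplet]

    assert scores == [9, 7, 9]

Deduplicating before the model call keeps the prompt smaller and guarantees a given edge is scored consistently no matter how many subgraphs it appears in.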
@@ -320,6 +323,7 @@ async def rank_relations( user: str, top_k: int, generative_model: Optional[str] = None, + score_threshold: int = 0, ) -> Relations: # Store the index for keeping track after scoring flat_rels: list[tuple[str, int, DirectionalRelation]] = [ @@ -341,14 +345,25 @@ async def rank_relations( } for (ent, _, rel) in flat_rels ] + # Dedupe triplets so that they get evaluated once, we will re-associate the scores later + triplet_to_orig_indices: dict[tuple[str, str, str], list[int]] = {} + unique_triplets = [] + + for i, t in enumerate(triplets): + key = (t["head_entity"], t["relationship"], t["tail_entity"]) + if key not in triplet_to_orig_indices: + triplet_to_orig_indices[key] = [] + unique_triplets.append(t) + triplet_to_orig_indices[key].append(i) + data = { "question": query, - "triplets": triplets, + "triplets": unique_triplets, } prompt = PROMPT + json.dumps(data, indent=4) predict = get_predict() - item = ChatModel( + chat_model = ChatModel( question=prompt, user_id=user, json_schema=SCHEMA, @@ -360,10 +375,10 @@ async def rank_relations( generative_model=generative_model, ) # TODO: Enclose this in a try-except block - ident, model, answer_stream = await predict.chat_query_ndjson(kbid, item) + ident, model, answer_stream = await predict.chat_query_ndjson(kbid, chat_model) response_json = None status = None - meta = None + _ = None async for generative_chunk in answer_stream: item = generative_chunk.chunk @@ -372,7 +387,7 @@ async def rank_relations( elif isinstance(item, StatusGenerativeResponse): status = item elif isinstance(item, MetaGenerativeResponse): - meta = item + _ = item else: # TODO: Improve for logging raise ValueError(f"Unknown generative chunk type: {item}") @@ -382,17 +397,37 @@ async def rank_relations( if response_json is None or status is None or status.code != "0": raise ValueError("No JSON response found") - scored_triplets = response_json.object["triplets"] + scored_unique_triplets: list[dict[str, Union[str, Any]]] = response_json.object["triplets"] - if len(scored_triplets) != len(flat_rels): + if len(scored_unique_triplets) != len(unique_triplets): raise ValueError("Mismatch between input and output triplets") - scores = ((idx, scored_triplet["score"]) for (idx, scored_triplet) in enumerate(scored_triplets)) + + # Re-expand model scores to the original triplets + scored_triplets: list[Optional[dict[str, Any]]] = [None] * len(triplets) + for unique_idx, scored_t in enumerate(scored_unique_triplets): + h, r, ta = ( + scored_t["head_entity"], + scored_t["relationship"], + scored_t["tail_entity"], + ) + for orig_idx in triplet_to_orig_indices[(h, r, ta)]: + scored_triplets[orig_idx] = scored_t + + if any(st is None for st in scored_triplets): + raise ValueError("Some triplets did not get a score assigned") + + if len(scored_triplets) != len(flat_rels): + raise ValueError("Mismatch between input and output triplets after expansion") + + scores = ((idx, t["score"]) for (idx, t) in enumerate(scored_triplets) if t is not None) + top_k_scores = heapq.nlargest(top_k, scores, key=lambda x: x[1]) - top_k_rels = defaultdict(lambda: EntitySubgraph(related_to=[])) - for idx_flat, _ in top_k_scores: + top_k_rels: dict[str, EntitySubgraph] = defaultdict(lambda: EntitySubgraph(related_to=[])) + for idx_flat, score in top_k_scores: (ent, idx, _) = flat_rels[idx_flat] rel = relations.entities[ent].related_to[idx] - top_k_rels[ent].related_to.append(rel) + if score > score_threshold: + top_k_rels[ent].related_to.append(rel) return Relations(entities=top_k_rels) @@ 
-438,3 +473,13 @@ async def build_graph_response(
         relations=final_relations,
         total=len(text_blocks),
     )
+
+
+def filter_subgraph(subgraph: EntitySubgraph, entities_to_remove: Collection[str]) -> EntitySubgraph:
+    """
+    Removes the relationships with entities in `entities_to_remove` from the subgraph.
+    """
+    return EntitySubgraph(
+        # TODO: Limit to 150 is temporary, remove it and add a reranker scoring?
+        related_to=[rel for rel in subgraph.related_to if rel.entity not in entities_to_remove][:150]
+    )
diff --git a/nucliadb/src/nucliadb/search/search/merge.py b/nucliadb/src/nucliadb/search/search/merge.py
index c2cd44363c..9eb44446ec 100644
--- a/nucliadb/src/nucliadb/search/search/merge.py
+++ b/nucliadb/src/nucliadb/search/search/merge.py
@@ -468,12 +468,12 @@ def _merge_relations_results(
             relation_type = RelationTypePbMap[relation.relation]
             relation_label = relation.relation_label
             metadata = relation.metadata if relation.HasField("metadata") else None
-
             if (not only_with_metadata or metadata) and origin.value in relations.entities:
                 relations.entities[origin.value].related_to.append(
                     DirectionalRelation(
                         entity=destination.value,
                         entity_type=relation_node_type_to_entity_type(destination.ntype),
+                        entity_subtype=destination.subtype,
                         relation=relation_type,
                         relation_label=relation_label,
                         direction=RelationDirection.OUT,
@@ -485,6 +485,7 @@
                     DirectionalRelation(
                         entity=origin.value,
                         entity_type=relation_node_type_to_entity_type(origin.ntype),
+                        entity_subtype=origin.subtype,
                         relation=relation_type,
                         relation_label=relation_label,
                         direction=RelationDirection.IN,
diff --git a/nucliadb/tests/search/unit/test_graph_strategy.py b/nucliadb/tests/search/unit/test_graph_strategy.py
new file mode 100644
index 0000000000..e5ef70bb11
--- /dev/null
+++ b/nucliadb/tests/search/unit/test_graph_strategy.py
@@ -0,0 +1,144 @@
+# Copyright (C) 2021 Bosutech XXI S.L.
+#
+# nucliadb is offered under the AGPL v3.0 and as commercial software.
+# For commercial licensing, contact us at info@nuclia.com.
+#
+# AGPL:
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Affero General Public License as
+# published by the Free Software Foundation, either version 3 of the
+# License, or (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Affero General Public License for more details.
+#
+# You should have received a copy of the GNU Affero General Public License
+# along with this program. If not, see <http://www.gnu.org/licenses/>.
+# +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest +from nuclia_models.predict.generative_responses import ( + JSONGenerativeResponse, + MetaGenerativeResponse, + StatusGenerativeResponse, +) + +from nucliadb.search.search.graph_strategy import rank_relations +from nucliadb_models.metadata import RelationType +from nucliadb_models.search import ( + DirectionalRelation, + EntitySubgraph, + EntityType, + RelationDirection, + Relations, +) + + +@pytest.mark.asyncio +@patch("nucliadb.search.search.graph_strategy.get_predict") +async def test_rank_relations( + mocker, +): + mock_json_response = { + "triplets": [ + { + "head_entity": "John Adams Roofing Inc.", + "relationship": "has product", + "tail_entity": "Socks", + "score": 2, + }, + { + "head_entity": "John Adams Roofing Inc.", + "relationship": "has product", + "tail_entity": "Titanium Grade 3 Roofing Nails", + "score": 10, + }, + { + "head_entity": "John Adams Roofing Inc.", + "relationship": "is located in", + "tail_entity": "New York", + "score": 6, + }, + { + "head_entity": "John Adams", + "relationship": "married to", + "tail_entity": "Abigail Adams", + "score": 0, + }, + ] + } + chat_mock = AsyncMock() + chat_mock.__aiter__.return_value = iter( + [ + MagicMock(chunk=JSONGenerativeResponse(object=mock_json_response)), + MagicMock(chunk=StatusGenerativeResponse(code="0")), + MagicMock(chunk=MetaGenerativeResponse(input_tokens=10, output_tokens=5, timings={})), + ] + ) + predict_mock = AsyncMock() + predict_mock.chat_query_ndjson.return_value = ("my_ident", "fake_llm", chat_mock) + mocker.return_value = predict_mock + relations = Relations( + entities={ + "John Adams Roofing Inc.": EntitySubgraph( + related_to=[ + DirectionalRelation( + entity="Socks", + entity_type=EntityType.ENTITY, + entity_subtype="PRODUCT", + relation_label="has product", + relation=RelationType.ENTITY, + direction=RelationDirection.OUT, + ), + DirectionalRelation( + entity="Titanium Grade 3 Roofing Nails", + entity_type=EntityType.ENTITY, + entity_subtype="PRODUCT", + relation_label="has product", + relation=RelationType.ENTITY, + direction=RelationDirection.OUT, + ), + DirectionalRelation( + entity="New York", + entity_type=EntityType.ENTITY, + entity_subtype="LOCATION", + relation_label="is located in", + relation=RelationType.ENTITY, + direction=RelationDirection.OUT, + ), + DirectionalRelation( + entity="New York", + entity_type=EntityType.ENTITY, + entity_subtype="PLACE", + relation_label="is located in", + relation=RelationType.ENTITY, + direction=RelationDirection.OUT, + ), + ] + ), + "John Adams": EntitySubgraph( + related_to=[ + DirectionalRelation( + entity="Abigail Adams", + entity_type=EntityType.ENTITY, + entity_subtype="PERSON", + relation_label="married to", + relation=RelationType.ENTITY, + direction=RelationDirection.OUT, + ) + ] + ), + } + ) + result = await rank_relations( + relations=relations, query="my_query", kbid="my_kbid", user="my_user", top_k=3, score_threshold=0 + ) + assert "John Adams Roofing Inc." 
in result.entities + assert ( + result.entities["John Adams Roofing Inc."].related_to + == relations.entities["John Adams Roofing Inc."].related_to[1:] + ) + assert "John Adams" not in result.entities diff --git a/nucliadb_models/src/nucliadb_models/search.py b/nucliadb_models/src/nucliadb_models/search.py index ac0ca579bc..03a8664f00 100644 --- a/nucliadb_models/src/nucliadb_models/search.py +++ b/nucliadb_models/src/nucliadb_models/search.py @@ -238,6 +238,7 @@ class EntityType(str, Enum): class DirectionalRelation(BaseModel): entity: str entity_type: EntityType + entity_subtype: str relation: RelationType relation_label: str direction: RelationDirection @@ -1221,10 +1222,15 @@ class GraphStrategy(RagStrategy): """ name: Literal["graph"] = "graph" - n_hops: int = Field( + hops: int = Field( default=1, title="Number of hops", - description="Number of hops to take when exploring the graph for relevant context. Biggers values will take more time to compute .", + description="""Number of hops to take when exploring the graph for relevant context. +For example, +- hops=1 will explore the neighbors of the starting entities. +- hops=2 will explore the neighbors of the neighbors of the starting entities. +And so on. +Bigger values will discover more intricate relationships but will also take more time to compute.""", ge=1, ) top_k: int = Field( From 65865dd50936435417e2482dbc5dc640792c8e0c Mon Sep 17 00:00:00 2001 From: Carles Onielfa Date: Mon, 13 Jan 2025 15:52:33 +0100 Subject: [PATCH 08/17] Refactor + fix tests --- .../src/nucliadb/search/search/chat/ask.py | 2 +- .../src/nucliadb/search/search/chat/query.py | 143 +---------------- .../nucliadb/search/search/graph_strategy.py | 146 ++++++++++++++++++ .../integration/search/test_search.py | 38 ++++- nucliadb_sdk/tests/test_ask.py | 2 + 5 files changed, 186 insertions(+), 145 deletions(-) diff --git a/nucliadb/src/nucliadb/search/search/chat/ask.py b/nucliadb/src/nucliadb/search/search/chat/ask.py index 1647f60d1f..4cbc6cd0ed 100644 --- a/nucliadb/src/nucliadb/search/search/chat/ask.py +++ b/nucliadb/src/nucliadb/search/search/chat/ask.py @@ -48,7 +48,6 @@ NOT_ENOUGH_CONTEXT_ANSWER, ChatAuditor, get_find_results, - get_graph_results, get_relations_results, rephrase_query, sorted_prompt_context_list, @@ -58,6 +57,7 @@ IncompleteFindResultsError, InvalidQueryError, ) +from nucliadb.search.search.graph_strategy import get_graph_results from nucliadb.search.search.metrics import RAGMetrics from nucliadb.search.search.query import QueryParser from nucliadb.search.utilities import get_predict diff --git a/nucliadb/src/nucliadb/search/search/chat/query.py b/nucliadb/src/nucliadb/search/search/chat/query.py index dc7e5cc245..12f5a927ce 100644 --- a/nucliadb/src/nucliadb/search/search/chat/query.py +++ b/nucliadb/src/nucliadb/search/search/chat/query.py @@ -26,13 +26,7 @@ from nucliadb.search.requesters.utils import Method, node_query from nucliadb.search.search.chat.exceptions import NoRetrievalResultsError from nucliadb.search.search.exceptions import IncompleteFindResultsError -from nucliadb.search.search.find import find, query_parser_from_find_request -from nucliadb.search.search.graph_strategy import ( - build_graph_response, - filter_subgraph, - fuzzy_search_entities, - rank_relations, -) +from nucliadb.search.search.find import find from nucliadb.search.search.merge import merge_relations_results from nucliadb.search.search.metrics import RAGMetrics from nucliadb.search.search.query import QueryParser @@ -43,7 +37,6 @@ ChatContextMessage, 
ChatOptions, FindRequest, - GraphStrategy, KnowledgeboxFindResults, NucliaDBClientType, PreQueriesStrategy, @@ -84,140 +77,6 @@ async def rephrase_query( return await predict.rephrase_query(kbid, req) -async def get_graph_results( - *, - kbid: str, - query: str, - item: AskRequest, - ndb_client: NucliaDBClientType, - user: str, - origin: str, - graph_strategy: GraphStrategy, - generative_model: Optional[str] = None, - metrics: RAGMetrics = RAGMetrics(), - shards: Optional[list[str]] = None, -) -> tuple[KnowledgeboxFindResults, QueryParser]: - # TODO: Timing using RAGMetrics - # TODO: Exception handling - # 1. Get relations from entities in query - # TODO: Send flag to predict entities to use DA entities once available - # TODO: Set this as an optional mode - # relations = await get_relations_results( - # kbid=kbid, - # text_answer=query, - # timeout=5.0, - # target_shard_replicas=shards, - # only_with_metadata=True, - # # use_da_entities=True, - # ) - suggest_result = await fuzzy_search_entities( - kbid=kbid, - query=query, - show=item.show, # This show might need to be manually set - field_type_filter=item.field_type_filter, - range_creation_start=item.range_creation_start, - range_creation_end=item.range_creation_end, - range_modification_start=item.range_modification_start, - range_modification_end=item.range_modification_end, - target_shard_replicas=shards, - ) - # Convert them to RelationNode in order to perform a relations query - if suggest_result.entities is not None: - relation_nodes = ( - RelationNode(ntype=RelationNode.NodeType.ENTITY, value=result.value, subtype=result.family) - for result in suggest_result.entities.entities - ) - relations = await get_relations_results_from_entities( - kbid=kbid, - entities=relation_nodes, - target_shard_replicas=suggest_result.shards, - timeout=5.0, - only_with_metadata=True, - ) - else: - relations = Relations(entities={}) - # TODO: Apply process_subgraph to the relations - - explored_entities = set(relations.entities.keys()) - - # 2. 
Rank the relations and get the top_k - # TODO: Add upper bound to the number of entities to explore for safety - relations = await rank_relations( - relations, query, kbid, user, top_k=graph_strategy.top_k, generative_model=generative_model - ) - - for hop in range(graph_strategy.hops - 1): - entities_to_explore: list[RelationNode] = [] - # Find neighbors of the pruned relations and remove the ones already explored - for subgraph in relations.entities.values(): - for relation in subgraph.related_to: - if relation.entity not in explored_entities: - entities_to_explore.append( - RelationNode( - ntype=RelationNode.NodeType.ENTITY, - value=relation.entity, - subtype=relation.entity_subtype, - ) - ) - - # Get the relations for the new entities - new_relations = await get_relations_results_from_entities( - kbid=kbid, - entities=entities_to_explore, - target_shard_replicas=shards, - timeout=5.0, - only_with_metadata=True, - ) - - # Removing the relations connected to the entities that were already explored - # XXX: This could be optimized by implementing a filter in the index - # so we don't have to remove them after - new_subgraphs = { - entity: filter_subgraph(subgraph, explored_entities) - for entity, subgraph in new_relations.entities.items() - } - if not new_subgraphs or any(not subgraph.related_to for subgraph in new_subgraphs.values()): - break - - explored_entities.update(new_subgraphs.keys()) - relations.entities.update(new_subgraphs) - - # Rank the new relations - relations = await rank_relations( - relations, - query, - kbid, - user, - top_k=graph_strategy.top_k, - generative_model=generative_model, - ) - - # 3. Get the text for the top_k relations - paragraph_ids = { - r.metadata.paragraph_id - for rel in relations.entities.values() - for r in rel.related_to - if r.metadata and r.metadata.paragraph_id - } - find_request = find_request_from_ask_request(item, query) - query_parser, rank_fusion, reranker = await query_parser_from_find_request( - kbid, find_request, generative_model=generative_model - ) - find_results = await build_graph_response( - paragraph_ids, - kbid=kbid, - query=query, - final_relations=relations, - top_k=graph_strategy.top_k, - reranker=reranker, - show=find_request.show, - extracted=find_request.extracted, - field_type_filter=find_request.field_type_filter, - ) - - return find_results, query_parser - - async def get_find_results( *, kbid: str, diff --git a/nucliadb/src/nucliadb/search/search/graph_strategy.py b/nucliadb/src/nucliadb/search/search/graph_strategy.py index a31ed77884..e0e71aab84 100644 --- a/nucliadb/src/nucliadb/search/search/graph_strategy.py +++ b/nucliadb/src/nucliadb/search/search/graph_strategy.py @@ -31,6 +31,11 @@ ) from nucliadb.search.requesters.utils import Method, node_query +from nucliadb.search.search.chat.query import ( + find_request_from_ask_request, + get_relations_results_from_entities, +) +from nucliadb.search.search.find import query_parser_from_find_request from nucliadb.search.search.find_merge import ( compose_find_resources, hydrate_and_rerank, @@ -38,22 +43,28 @@ ) from nucliadb.search.search.hydrator import ResourceHydrationOptions, TextBlockHydrationOptions from nucliadb.search.search.merge import merge_suggest_results +from nucliadb.search.search.metrics import RAGMetrics +from nucliadb.search.search.query import QueryParser from nucliadb.search.search.rerankers import Reranker, RerankingOptions from nucliadb.search.utilities import get_predict from nucliadb_models.common import FieldTypeName from nucliadb_models.resource 
import ExtractedDataTypeName from nucliadb_models.search import ( + AskRequest, ChatModel, DirectionalRelation, EntitySubgraph, + GraphStrategy, KnowledgeboxFindResults, KnowledgeboxSuggestResults, + NucliaDBClientType, RelationDirection, Relations, ResourceProperties, UserPrompt, ) from nucliadb_protos import nodereader_pb2 +from nucliadb_protos.utils_pb2 import RelationNode SCHEMA = { "title": "score_triplets", @@ -268,6 +279,141 @@ """ +async def get_graph_results( + *, + kbid: str, + query: str, + item: AskRequest, + ndb_client: NucliaDBClientType, + user: str, + origin: str, + graph_strategy: GraphStrategy, + generative_model: Optional[str] = None, + metrics: RAGMetrics = RAGMetrics(), + shards: Optional[list[str]] = None, +) -> tuple[KnowledgeboxFindResults, QueryParser]: + # TODO: Timing using RAGMetrics + # TODO: Exception handling + # 1. Get relations from entities in query + # TODO: Send flag to predict entities to use DA entities once available + # TODO: Set this as an optional mode + # relations = await get_relations_results( + # kbid=kbid, + # text_answer=query, + # timeout=5.0, + # target_shard_replicas=shards, + # only_with_metadata=True, + # # use_da_entities=True, + # ) + suggest_result = await fuzzy_search_entities( + kbid=kbid, + query=query, + show=[], + field_type_filter=item.field_type_filter, + range_creation_start=item.range_creation_start, + range_creation_end=item.range_creation_end, + range_modification_start=item.range_modification_start, + range_modification_end=item.range_modification_end, + target_shard_replicas=shards, + ) + + # Convert them to RelationNode in order to perform a relations query + if suggest_result.entities is not None: + relation_nodes = ( + RelationNode(ntype=RelationNode.NodeType.ENTITY, value=result.value, subtype=result.family) + for result in suggest_result.entities.entities + ) + relations = await get_relations_results_from_entities( + kbid=kbid, + entities=relation_nodes, + target_shard_replicas=shards, + timeout=5.0, + only_with_metadata=True, + ) + else: + relations = Relations(entities={}) + # TODO: Apply process_subgraph to the relations + + explored_entities = set(relations.entities.keys()) + + # 2. 
Rank the relations and get the top_k + # TODO: Add upper bound to the number of entities to explore for safety + relations = await rank_relations( + relations, query, kbid, user, top_k=graph_strategy.top_k, generative_model=generative_model + ) + + for hop in range(graph_strategy.hops - 1): + entities_to_explore: list[RelationNode] = [] + # Find neighbors of the pruned relations and remove the ones already explored + for subgraph in relations.entities.values(): + for relation in subgraph.related_to: + if relation.entity not in explored_entities: + entities_to_explore.append( + RelationNode( + ntype=RelationNode.NodeType.ENTITY, + value=relation.entity, + subtype=relation.entity_subtype, + ) + ) + + # Get the relations for the new entities + new_relations = await get_relations_results_from_entities( + kbid=kbid, + entities=entities_to_explore, + target_shard_replicas=shards, + timeout=5.0, + only_with_metadata=True, + ) + + # Removing the relations connected to the entities that were already explored + # XXX: This could be optimized by implementing a filter in the index + # so we don't have to remove them after + new_subgraphs = { + entity: filter_subgraph(subgraph, explored_entities) + for entity, subgraph in new_relations.entities.items() + } + if not new_subgraphs or any(not subgraph.related_to for subgraph in new_subgraphs.values()): + break + + explored_entities.update(new_subgraphs.keys()) + relations.entities.update(new_subgraphs) + + # Rank the new relations + relations = await rank_relations( + relations, + query, + kbid, + user, + top_k=graph_strategy.top_k, + generative_model=generative_model, + ) + + # 3. Get the text for the top_k relations + paragraph_ids = { + r.metadata.paragraph_id + for rel in relations.entities.values() + for r in rel.related_to + if r.metadata and r.metadata.paragraph_id + } + find_request = find_request_from_ask_request(item, query) + query_parser, rank_fusion, reranker = await query_parser_from_find_request( + kbid, find_request, generative_model=generative_model + ) + find_results = await build_graph_response( + paragraph_ids, + kbid=kbid, + query=query, + final_relations=relations, + top_k=graph_strategy.top_k, + reranker=reranker, + show=find_request.show, + extracted=find_request.extracted, + field_type_filter=find_request.field_type_filter, + ) + + return find_results, query_parser + + async def fuzzy_search_entities( kbid: str, query: str, diff --git a/nucliadb/tests/nucliadb/integration/search/test_search.py b/nucliadb/tests/nucliadb/integration/search/test_search.py index 64364771ac..8194d3f945 100644 --- a/nucliadb/tests/nucliadb/integration/search/test_search.py +++ b/nucliadb/tests/nucliadb/integration/search/test_search.py @@ -471,6 +471,8 @@ async def test_search_relations( "relation": "ENTITY", "relation_label": "write", "direction": "out", + "entity_subtype": "", + "metadata": None, }, { "entity": "Poetry", @@ -478,6 +480,8 @@ async def test_search_relations( "relation": "ENTITY", "relation_label": "like", "direction": "out", + "entity_subtype": "", + "metadata": None, }, { "entity": "Joan Antoni", @@ -485,6 +489,8 @@ async def test_search_relations( "relation": "ENTITY", "relation_label": "read", "direction": "in", + "entity_subtype": "", + "metadata": None, }, ] }, @@ -496,6 +502,8 @@ async def test_search_relations( "relation": "ENTITY", "relation_label": "formulate", "direction": "out", + "entity_subtype": "", + "metadata": None, }, { "entity": "Physics", @@ -503,6 +511,8 @@ async def test_search_relations( "relation": "ENTITY", 
"relation_label": "study", "direction": "out", + "entity_subtype": "science", + "metadata": None, }, ] }, @@ -536,6 +546,8 @@ async def test_search_relations( "relation": "ENTITY", "relation_label": "species", "direction": "in", + "entity_subtype": "", + "metadata": None, }, { "entity": "Swallow", @@ -543,6 +555,8 @@ async def test_search_relations( "relation": "ENTITY", "relation_label": "species", "direction": "in", + "entity_subtype": "", + "metadata": None, }, ] }, @@ -644,6 +658,8 @@ async def test_search_automatic_relations( "relation": "COLAB", "relation_label": "", "direction": "out", + "entity_subtype": "PERSON", + "metadata": None, }, { "entity": "Anne", @@ -651,6 +667,8 @@ async def test_search_automatic_relations( "relation": "COLAB", "relation_label": "", "direction": "out", + "entity_subtype": "", + "metadata": None, }, { "entity": "John", @@ -658,6 +676,8 @@ async def test_search_automatic_relations( "relation": "COLAB", "relation_label": "", "direction": "out", + "entity_subtype": "", + "metadata": None, }, { "entity": "cat", @@ -665,6 +685,8 @@ async def test_search_automatic_relations( "relation": "ENTITY", "relation_label": "", "direction": "out", + "entity_subtype": "ANIMAL", + "metadata": None, }, { "entity": "label", @@ -672,6 +694,8 @@ async def test_search_automatic_relations( "relation": "ABOUT", "relation_label": "", "direction": "out", + "entity_subtype": "", + "metadata": None, }, { "entity": "animals/cat", @@ -679,6 +703,8 @@ async def test_search_automatic_relations( "relation": "ABOUT", "relation_label": "", "direction": "out", + "entity_subtype": "", + "metadata": None, }, { "entity": "food/cookie", @@ -686,6 +712,8 @@ async def test_search_automatic_relations( "relation": "ABOUT", "relation_label": "", "direction": "out", + "entity_subtype": "", + "metadata": None, }, { "entity": "sub-document", @@ -693,6 +721,8 @@ async def test_search_automatic_relations( "relation": "CHILD", "relation_label": "", "direction": "out", + "entity_subtype": "", + "metadata": None, }, { "entity": "other", @@ -700,6 +730,8 @@ async def test_search_automatic_relations( "relation": "OTHER", "relation_label": "", "direction": "out", + "entity_subtype": "things", + "metadata": None, }, ] } @@ -709,7 +741,7 @@ async def test_search_automatic_relations( assert entity in entities assert len(entities[entity]["related_to"]) == len(expected[entity]["related_to"]) - assert sorted(expected[entity]["related_to"], key=lambda x: x["entity"]) == sorted( + assert sorted(expected[entity]["related_to"], key=lambda x: x["entity"]) == sorted( # type: ignore entities[entity]["related_to"], key=lambda x: x["entity"] ) @@ -739,6 +771,8 @@ async def test_search_automatic_relations( "relation": "COLAB", "relation_label": "", "direction": "in", + "entity_subtype": "", + "metadata": None, } ] } @@ -748,7 +782,7 @@ async def test_search_automatic_relations( assert entity in entities assert len(entities[entity]["related_to"]) == len(expected[entity]["related_to"]) - assert sorted(expected[entity]["related_to"], key=lambda x: x["entity"]) == sorted( + assert sorted(expected[entity]["related_to"], key=lambda x: x["entity"]) == sorted( # type: ignore entities[entity]["related_to"], key=lambda x: x["entity"] ) diff --git a/nucliadb_sdk/tests/test_ask.py b/nucliadb_sdk/tests/test_ask.py index 9c3ca61a1d..38356b02a8 100644 --- a/nucliadb_sdk/tests/test_ask.py +++ b/nucliadb_sdk/tests/test_ask.py @@ -115,6 +115,7 @@ def test_ask_response_parser_stream(): DirectionalRelation( entity="Semantic Search", 
entity_type=EntityType.ENTITY, + entity_subtype="concept", relation=RelationType.ABOUT, relation_label="performing", direction=RelationDirection.OUT, @@ -145,6 +146,7 @@ def test_ask_response_parser_stream(): assert ask_response.answer == "This is your Nuclia answer." assert ask_response.status == "success" assert ask_response.relations.entities["Nuclia"].related_to[0].entity == "Semantic Search" + assert ask_response.relations.entities["Nuclia"].related_to[0].entity_subtype == "concept" assert ask_response.citations["some/paragraph/id"] == "This is a citation" assert ask_response.retrieval_results.resources == {} assert ask_response.metadata.tokens.input == 10 From 098fdffdd786d8dfe6e2d5a71a7802ef6b64bd54 Mon Sep 17 00:00:00 2001 From: Carles Onielfa Date: Mon, 13 Jan 2025 17:20:31 +0100 Subject: [PATCH 09/17] Refactor --- .../nucliadb/search/search/graph_strategy.py | 233 +++++++++--------- 1 file changed, 114 insertions(+), 119 deletions(-) diff --git a/nucliadb/src/nucliadb/search/search/graph_strategy.py b/nucliadb/src/nucliadb/search/search/graph_strategy.py index e0e71aab84..f2365d9c61 100644 --- a/nucliadb/src/nucliadb/search/search/graph_strategy.py +++ b/nucliadb/src/nucliadb/search/search/graph_strategy.py @@ -29,7 +29,9 @@ MetaGenerativeResponse, StatusGenerativeResponse, ) +from sentry_sdk import capture_exception +from nucliadb.search import logger from nucliadb.search.requesters.utils import Method, node_query from nucliadb.search.search.chat.query import ( find_request_from_ask_request, @@ -292,125 +294,113 @@ async def get_graph_results( metrics: RAGMetrics = RAGMetrics(), shards: Optional[list[str]] = None, ) -> tuple[KnowledgeboxFindResults, QueryParser]: - # TODO: Timing using RAGMetrics - # TODO: Exception handling - # 1. Get relations from entities in query - # TODO: Send flag to predict entities to use DA entities once available - # TODO: Set this as an optional mode - # relations = await get_relations_results( - # kbid=kbid, - # text_answer=query, - # timeout=5.0, - # target_shard_replicas=shards, - # only_with_metadata=True, - # # use_da_entities=True, - # ) - suggest_result = await fuzzy_search_entities( - kbid=kbid, - query=query, - show=[], - field_type_filter=item.field_type_filter, - range_creation_start=item.range_creation_start, - range_creation_end=item.range_creation_end, - range_modification_start=item.range_modification_start, - range_modification_end=item.range_modification_end, - target_shard_replicas=shards, - ) - - # Convert them to RelationNode in order to perform a relations query - if suggest_result.entities is not None: - relation_nodes = ( - RelationNode(ntype=RelationNode.NodeType.ENTITY, value=result.value, subtype=result.family) - for result in suggest_result.entities.entities - ) - relations = await get_relations_results_from_entities( - kbid=kbid, - entities=relation_nodes, - target_shard_replicas=shards, - timeout=5.0, - only_with_metadata=True, - ) - else: - relations = Relations(entities={}) - # TODO: Apply process_subgraph to the relations - - explored_entities = set(relations.entities.keys()) - - # 2. 
Rank the relations and get the top_k - # TODO: Add upper bound to the number of entities to explore for safety - relations = await rank_relations( - relations, query, kbid, user, top_k=graph_strategy.top_k, generative_model=generative_model - ) - - for hop in range(graph_strategy.hops - 1): - entities_to_explore: list[RelationNode] = [] - # Find neighbors of the pruned relations and remove the ones already explored - for subgraph in relations.entities.values(): - for relation in subgraph.related_to: - if relation.entity not in explored_entities: - entities_to_explore.append( - RelationNode( - ntype=RelationNode.NodeType.ENTITY, - value=relation.entity, - subtype=relation.entity_subtype, - ) + relations = Relations(entities={}) + explored_entities: set[str] = set() + + for hop in range(graph_strategy.hops): + entities_to_explore: Iterable[RelationNode] = [] + if hop == 0: + # Get the entities from the query + with metrics.time("graph_strat_query_entities"): + suggest_result = await fuzzy_search_entities( + kbid=kbid, + query=query, + show=[], + field_type_filter=item.field_type_filter, + range_creation_start=item.range_creation_start, + range_creation_end=item.range_creation_end, + range_modification_start=item.range_modification_start, + range_modification_end=item.range_modification_end, + target_shard_replicas=shards, + ) + + if suggest_result.entities is not None: + entities_to_explore = ( + RelationNode( + ntype=RelationNode.NodeType.ENTITY, value=result.value, subtype=result.family ) + for result in suggest_result.entities.entities + ) + else: + entities_to_explore = [] + else: + # Find neighbors of the current relations and remove the ones already explored + entities_to_explore = ( + RelationNode( + ntype=RelationNode.NodeType.ENTITY, + value=relation.entity, + subtype=relation.entity_subtype, + ) + for subgraph in relations.entities.values() + for relation in subgraph.related_to + if relation.entity not in explored_entities + ) # Get the relations for the new entities - new_relations = await get_relations_results_from_entities( - kbid=kbid, - entities=entities_to_explore, - target_shard_replicas=shards, - timeout=5.0, - only_with_metadata=True, - ) - - # Removing the relations connected to the entities that were already explored - # XXX: This could be optimized by implementing a filter in the index - # so we don't have to remove them after - new_subgraphs = { - entity: filter_subgraph(subgraph, explored_entities) - for entity, subgraph in new_relations.entities.items() + with metrics.time("graph_strat_neighbor_relations"): + try: + new_relations = await get_relations_results_from_entities( + kbid=kbid, + entities=entities_to_explore, + target_shard_replicas=shards, + timeout=5.0, + only_with_metadata=True, + ) + except Exception as e: + capture_exception(e) + logger.exception("Error in getting query relations for graph strategy") + new_relations = Relations(entities={}) + + # Removing the relations connected to the entities that were already explored + # XXX: This could be optimized by implementing a filter in the index + # so we don't have to remove them after + new_subgraphs = { + entity: filter_subgraph(subgraph, explored_entities) + for entity, subgraph in new_relations.entities.items() + } + + if not new_subgraphs or any(not subgraph.related_to for subgraph in new_subgraphs.values()): + break + + explored_entities.update(new_subgraphs.keys()) + relations.entities.update(new_subgraphs) + + # Rank the relevance of the relations + # TODO: Add upper bound to the number of entities 
to explore for safety + with metrics.time("graph_strat_rank_relations"): + relations = await rank_relations( + relations, + query, + kbid, + user, + top_k=graph_strategy.top_k, + generative_model=generative_model, + ) + + # Get the text blocks of the paragraphs that contain the top relations + with metrics.time("graph_strat_build_response"): + paragraph_ids = { + r.metadata.paragraph_id + for rel in relations.entities.values() + for r in rel.related_to + if r.metadata and r.metadata.paragraph_id } - if not new_subgraphs or any(not subgraph.related_to for subgraph in new_subgraphs.values()): - break - - explored_entities.update(new_subgraphs.keys()) - relations.entities.update(new_subgraphs) - - # Rank the new relations - relations = await rank_relations( - relations, - query, - kbid, - user, + find_request = find_request_from_ask_request(item, query) + query_parser, rank_fusion, reranker = await query_parser_from_find_request( + kbid, find_request, generative_model=generative_model + ) + find_results = await build_graph_response( + paragraph_ids, + kbid=kbid, + query=query, + final_relations=relations, top_k=graph_strategy.top_k, - generative_model=generative_model, + reranker=reranker, + show=find_request.show, + extracted=find_request.extracted, + field_type_filter=find_request.field_type_filter, ) - # 3. Get the text for the top_k relations - paragraph_ids = { - r.metadata.paragraph_id - for rel in relations.entities.values() - for r in rel.related_to - if r.metadata and r.metadata.paragraph_id - } - find_request = find_request_from_ask_request(item, query) - query_parser, rank_fusion, reranker = await query_parser_from_find_request( - kbid, find_request, generative_model=generative_model - ) - find_results = await build_graph_response( - paragraph_ids, - kbid=kbid, - query=query, - final_relations=relations, - top_k=graph_strategy.top_k, - reranker=reranker, - show=find_request.show, - extracted=find_request.extracted, - field_type_filter=find_request.field_type_filter, - ) - return find_results, query_parser @@ -453,13 +443,18 @@ async def fuzzy_search_entities( # Gather # TODO: What do I do with `incomplete_results`? 
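Around this hunk, fuzzy_search_entities is doing a word-by-word fan-out: rather than suggesting on the whole sentence, each query word becomes its own SUGGEST request, the requests run concurrently, and the partial results are merged. A self-contained sketch of the idea; suggest_word is a hypothetical stand-in for the node_query(kbid, Method.SUGGEST, ...) call, and the length cutoff is an assumed noise filter, not the exact production rule:

    import asyncio

    # Hypothetical stand-in for node_query(kbid, Method.SUGGEST, request);
    # a real implementation would query the suggest index shards per word.
    async def suggest_word(kbid: str, word: str) -> list[str]:
        toy_index = {"nolan": ["Christopher Nolan"], "dicaprio": ["Leonardo DiCaprio", "DiCaprio"]}
        return toy_index.get(word.lower(), [])

    async def fuzzy_find_entities(kbid: str, query: str) -> list[str]:
        # One suggest request per query word, issued concurrently and
        # flattened afterwards, mirroring the asyncio.gather +
        # merge_suggest_results flow in this hunk.
        words = [w for w in query.split() if len(w) > 2]
        per_word = await asyncio.gather(*(suggest_word(kbid, w) for w in words))
        return [entity for result in per_word for entity in result]

    print(asyncio.run(fuzzy_find_entities("kb1", "movies by Nolan and DiCaprio")))

Splitting by words is what makes the lookup fuzzy: an entity only has to match one token of the question, not the question as a whole.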
- results_raw = await asyncio.gather(*tasks) - return await merge_suggest_results( - [item for r in results_raw for item in r[0]], - kbid=kbid, - show=show, - field_type_filter=field_type_filter, - ) + try: + results_raw = await asyncio.gather(*tasks) + return await merge_suggest_results( + [item for r in results_raw for item in r[0]], + kbid=kbid, + show=show, + field_type_filter=field_type_filter, + ) + except Exception as e: + capture_exception(e) + logger.exception("Error in finding entities in query for graph strategy") + return KnowledgeboxSuggestResults(entities=None) async def rank_relations( From 7583e11051afd108827f852fce1cec216feef96e Mon Sep 17 00:00:00 2001 From: Carles Onielfa Date: Tue, 14 Jan 2025 10:43:49 +0100 Subject: [PATCH 10/17] Better testing --- .../src/nucliadb/search/search/chat/query.py | 5 +- .../nucliadb/search/search/graph_strategy.py | 33 ++-- nucliadb/src/nucliadb/search/search/merge.py | 56 ++++--- .../tests/nucliadb/integration/test_ask.py | 158 ++++++++++++++++++ .../unit/{ => search}/test_graph_strategy.py | 0 nucliadb_models/src/nucliadb_models/search.py | 11 +- 6 files changed, 223 insertions(+), 40 deletions(-) rename nucliadb/tests/search/unit/{ => search}/test_graph_strategy.py (100%) diff --git a/nucliadb/src/nucliadb/search/search/chat/query.py b/nucliadb/src/nucliadb/search/search/chat/query.py index 12f5a927ce..1701e30d42 100644 --- a/nucliadb/src/nucliadb/search/search/chat/query.py +++ b/nucliadb/src/nucliadb/search/search/chat/query.py @@ -217,6 +217,7 @@ async def get_relations_results( target_shard_replicas: Optional[list[str]], timeout: Optional[float] = None, only_with_metadata: bool = False, + only_agentic_relations: bool = False, ) -> Relations: try: predict = get_predict() @@ -228,6 +229,7 @@ async def get_relations_results( target_shard_replicas=target_shard_replicas, timeout=timeout, only_with_metadata=only_with_metadata, + only_agentic_relations=only_agentic_relations, ) except Exception as exc: capture_exception(exc) @@ -242,6 +244,7 @@ async def get_relations_results_from_entities( target_shard_replicas: Optional[list[str]], timeout: Optional[float] = None, only_with_metadata: bool = False, + only_agentic_relations: bool = False, ) -> Relations: request = SearchRequest() request.relation_subgraph.entry_points.extend(entities) @@ -262,7 +265,7 @@ async def get_relations_results_from_entities( ) relations_results: list[RelationSearchResponse] = [result.relation for result in results] return await merge_relations_results( - relations_results, request.relation_subgraph, only_with_metadata + relations_results, request.relation_subgraph, only_with_metadata, only_agentic_relations ) diff --git a/nucliadb/src/nucliadb/search/search/graph_strategy.py b/nucliadb/src/nucliadb/search/search/graph_strategy.py index f2365d9c61..5c907fa722 100644 --- a/nucliadb/src/nucliadb/search/search/graph_strategy.py +++ b/nucliadb/src/nucliadb/search/search/graph_strategy.py @@ -313,7 +313,6 @@ async def get_graph_results( range_modification_end=item.range_modification_end, target_shard_replicas=shards, ) - if suggest_result.entities is not None: entities_to_explore = ( RelationNode( @@ -345,6 +344,7 @@ async def get_graph_results( target_shard_replicas=shards, timeout=5.0, only_with_metadata=True, + only_agentic_relations=graph_strategy.agentic_graph_only, ) except Exception as e: capture_exception(e) @@ -359,7 +359,7 @@ async def get_graph_results( for entity, subgraph in new_relations.entities.items() } - if not new_subgraphs or any(not 
subgraph.related_to for subgraph in new_subgraphs.values()): + if not new_subgraphs or all(not subgraph.related_to for subgraph in new_subgraphs.values()): break explored_entities.update(new_subgraphs.keys()) @@ -368,14 +368,19 @@ async def get_graph_results( # Rank the relevance of the relations # TODO: Add upper bound to the number of entities to explore for safety with metrics.time("graph_strat_rank_relations"): - relations = await rank_relations( - relations, - query, - kbid, - user, - top_k=graph_strategy.top_k, - generative_model=generative_model, - ) + try: + relations = await rank_relations( + relations, + query, + kbid, + user, + top_k=graph_strategy.top_k, + generative_model=generative_model, + ) + except Exception as e: + capture_exception(e) + logger.exception("Error in ranking relations for graph strategy") + break # Get the text blocks of the paragraphs that contain the top relations with metrics.time("graph_strat_build_response"): @@ -465,13 +470,16 @@ async def rank_relations( top_k: int, generative_model: Optional[str] = None, score_threshold: int = 0, + max_rels_to_eval: int = 300, ) -> Relations: # Store the index for keeping track after scoring + # XXX: Here we set a hard limit on the number of relations to evaluate for safety and performance + # In the future we could to several iterations of scoring flat_rels: list[tuple[str, int, DirectionalRelation]] = [ (ent, idx, rel) for (ent, rels) in relations.entities.items() for (idx, rel) in enumerate(rels.related_to) - ] + ][:max_rels_to_eval] triplets: list[dict[str, str]] = [ { "head_entity": ent, @@ -621,6 +629,5 @@ def filter_subgraph(subgraph: EntitySubgraph, entities_to_remove: Collection[str Removes the relationships with entities in `entities_to_remove` from the subgraph. """ return EntitySubgraph( - # TODO: Limit to 150 is temporary, remove it and add a reranker scoring? 
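The any() to all() change at the top of this hunk fixes an early exit: previously a single entity whose new subgraph came back empty aborted the whole hop loop, even when other entities still had unexplored neighbors; now exploration stops only when every new subgraph is empty. Stripped of the NucliaDB plumbing, the loop reduces to a frontier-expansion skeleton like the sketch below, where plain dicts stand in for Relations/EntitySubgraph and a trivial slice stands in for the LLM ranker:

    def explore(graph: dict[str, list[str]], seeds: list[str], hops: int, top_k: int) -> dict[str, list[str]]:
        relations: dict[str, list[str]] = {}
        explored: set[str] = set()
        frontier = list(seeds)
        for _ in range(hops):
            # Fetch neighbors, dropping edges back into already-explored territory.
            new = {e: [n for n in graph.get(e, []) if n not in explored] for e in frontier}
            # all(), not any(): stop only when *every* new subgraph is empty.
            if not new or all(not neighbors for neighbors in new.values()):
                break
            explored.update(new)
            relations.update(new)
            # Stub ranker: keep top_k edges per entity (the real code asks an LLM).
            relations = {e: ns[:top_k] for e, ns in relations.items()}
            frontier = [n for ns in new.values() for n in ns]
        return relations

    toy_graph = {"Christopher Nolan": ["Inception"], "Inception": ["Leonardo DiCaprio", "Joseph Gordon-Levitt"]}
    print(explore(toy_graph, ["Christopher Nolan"], hops=2, top_k=5))

With the old any() condition, a dead-end entity in the frontier would have stopped this toy run after the first hop even though Inception still had two unexplored neighbors.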
- related_to=[rel for rel in subgraph.related_to if rel.entity not in entities_to_remove][:150] + related_to=[rel for rel in subgraph.related_to if rel.entity not in entities_to_remove] ) diff --git a/nucliadb/src/nucliadb/search/search/merge.py b/nucliadb/src/nucliadb/search/search/merge.py index 9eb44446ec..fa9f26f3ab 100644 --- a/nucliadb/src/nucliadb/search/search/merge.py +++ b/nucliadb/src/nucliadb/search/search/merge.py @@ -444,10 +444,16 @@ async def merge_relations_results( relations_responses: list[RelationSearchResponse], query: EntitiesSubgraphRequest, only_with_metadata: bool = False, + only_agentic: bool = False, ) -> Relations: loop = asyncio.get_event_loop() return await loop.run_in_executor( - None, _merge_relations_results, relations_responses, query, only_with_metadata + None, + _merge_relations_results, + relations_responses, + query, + only_with_metadata, + only_agentic, ) @@ -455,6 +461,7 @@ def _merge_relations_results( relations_responses: list[RelationSearchResponse], query: EntitiesSubgraphRequest, only_with_metadata: bool, + only_agentic: bool, ) -> Relations: relations = Relations(entities={}) @@ -468,30 +475,33 @@ def _merge_relations_results( relation_type = RelationTypePbMap[relation.relation] relation_label = relation.relation_label metadata = relation.metadata if relation.HasField("metadata") else None - if (not only_with_metadata or metadata) and origin.value in relations.entities: - relations.entities[origin.value].related_to.append( - DirectionalRelation( - entity=destination.value, - entity_type=relation_node_type_to_entity_type(destination.ntype), - entity_subtype=destination.subtype, - relation=relation_type, - relation_label=relation_label, - direction=RelationDirection.OUT, - metadata=from_proto.relation_metadata(metadata) if metadata else None, + if (not only_with_metadata or metadata) and ( + not only_agentic or (metadata and metadata.data_augmentation_task_id) + ): + if origin.value in relations.entities: + relations.entities[origin.value].related_to.append( + DirectionalRelation( + entity=destination.value, + entity_type=relation_node_type_to_entity_type(destination.ntype), + entity_subtype=destination.subtype, + relation=relation_type, + relation_label=relation_label, + direction=RelationDirection.OUT, + metadata=from_proto.relation_metadata(metadata) if metadata else None, + ) ) - ) - elif (not only_with_metadata or metadata) and destination.value in relations.entities: - relations.entities[destination.value].related_to.append( - DirectionalRelation( - entity=origin.value, - entity_type=relation_node_type_to_entity_type(origin.ntype), - entity_subtype=origin.subtype, - relation=relation_type, - relation_label=relation_label, - direction=RelationDirection.IN, - metadata=from_proto.relation_metadata(metadata) if metadata else None, + elif destination.value in relations.entities: + relations.entities[destination.value].related_to.append( + DirectionalRelation( + entity=origin.value, + entity_type=relation_node_type_to_entity_type(origin.ntype), + entity_subtype=origin.subtype, + relation=relation_type, + relation_label=relation_label, + direction=RelationDirection.IN, + metadata=from_proto.relation_metadata(metadata) if metadata else None, + ) ) - ) return relations diff --git a/nucliadb/tests/nucliadb/integration/test_ask.py b/nucliadb/tests/nucliadb/integration/test_ask.py index c813bdc885..ca6353ef72 100644 --- a/nucliadb/tests/nucliadb/integration/test_ask.py +++ b/nucliadb/tests/nucliadb/integration/test_ask.py @@ -20,6 +20,7 @@ import json 
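The reshaped condition in _merge_relations_results above composes two independent gates: metadata may be required at all, and, when only agent-generated relations are wanted, that metadata must also carry a data_augmentation_task_id. An illustrative restatement as a standalone predicate; the real code reads protobuf RelationMetadata, not a dict:

    from typing import Optional

    def keep_relation(
        metadata: Optional[dict],
        only_with_metadata: bool,
        only_agentic: bool,
    ) -> bool:
        # Gate 1: metadata is required but missing.
        if only_with_metadata and not metadata:
            return False
        # Gate 2: only agent-extracted relations are wanted, and this one
        # has no data-augmentation task id to prove its origin.
        if only_agentic and not (metadata and metadata.get("data_augmentation_task_id")):
            return False
        return True

    assert keep_relation({"data_augmentation_task_id": "my_graph_task_id"}, True, True)
    assert not keep_relation({"paragraph_id": "rid/t/field/0-10"}, True, True)
    assert keep_relation(None, False, False)

Keeping the gates separate is what lets agentic_graph_only later toggle the second check without disturbing the first.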
from itertools import combinations from unittest import mock +from unittest.mock import patch import pytest from httpx import AsyncClient @@ -47,6 +48,10 @@ RagStrategies, SyncAskResponse, ) +from nucliadb_protos.utils_pb2 import Relation, RelationMetadata, RelationNode +from nucliadb_protos.writer_pb2 import BrokerMessage +from tests.utils import inject_message +from tests.utils.dirty_index import wait_for_sync @pytest.fixture(scope="function", autouse=True) @@ -110,6 +115,91 @@ async def resource(nucliadb_writer, knowledgebox): yield rid +@pytest.fixture +async def graph_resource(nucliadb_writer, nucliadb_grpc, knowledgebox): + resp = await nucliadb_writer.post( + f"/kb/{knowledgebox}/resources", + json={ + "title": "Knowledge graph", + "slug": "knowledgegraph", + "summary": "Test knowledge graph", + "texts": { + "inception1": {"body": "Christopher Nolan directed Inception. Very interesting movie."}, + "inception2": {"body": "Leonardo DiCaprio starred in Inception."}, + "inception3": {"body": "Joseph Gordon-Levitt starred in Inception."}, + "leo": {"body": "Leonardo DiCaprio is a great actor. DiCaprio started acting in 1989."}, + }, + }, + ) + assert resp.status_code == 201 + rid = resp.json()["uuid"] + + nodes = { + "nolan": RelationNode( + value="Christopher Nolan", ntype=RelationNode.NodeType.ENTITY, subtype="DIRECTOR" + ), + "inception": RelationNode( + value="Inception", ntype=RelationNode.NodeType.ENTITY, subtype="MOVIE" + ), + "leo": RelationNode( + value="Leonardo DiCaprio", ntype=RelationNode.NodeType.ENTITY, subtype="ACTOR" + ), + "dicaprio": RelationNode(value="DiCaprio", ntype=RelationNode.NodeType.ENTITY, subtype="ACTOR"), + "levitt": RelationNode( + value="Joseph Gordon-Levitt", ntype=RelationNode.NodeType.ENTITY, subtype="ACTOR" + ), + } + edges = [ + Relation( + relation=Relation.RelationType.ENTITY, + source=nodes["nolan"], + to=nodes["inception"], + relation_label="directed", + metadata=RelationMetadata( + paragraph_id=rid + "/t/inception1/0-37", + data_augmentation_task_id="my_graph_task_id", + ), + ), + Relation( + relation=Relation.RelationType.ENTITY, + source=nodes["leo"], + to=nodes["inception"], + relation_label="starred", + metadata=RelationMetadata( + paragraph_id=rid + "/t/inception2/0-39", + data_augmentation_task_id="my_graph_task_id", + ), + ), + Relation( + relation=Relation.RelationType.ENTITY, + source=nodes["levitt"], + to=nodes["inception"], + relation_label="starred", + metadata=RelationMetadata( + paragraph_id=rid + "/t/inception3/0-42", + data_augmentation_task_id="", + ), + ), + Relation( + relation=Relation.RelationType.ENTITY, + source=nodes["leo"], + to=nodes["dicaprio"], + relation_label="analogy", + metadata=RelationMetadata( + paragraph_id=rid + "/t/leo/0-70", + data_augmentation_task_id="my_graph_task_id", + ), + ), + ] + bm = BrokerMessage() + bm.uuid = rid + bm.kbid = knowledgebox + bm.relations.extend(edges) + await inject_message(nucliadb_grpc, bm) + await wait_for_sync() + return rid + + async def test_ask_synchronous(nucliadb_reader: AsyncClient, knowledgebox, resource): resp = await nucliadb_reader.post( f"/kb/{knowledgebox}/ask", @@ -706,6 +796,74 @@ async def test_ask_top_k(nucliadb_reader: AsyncClient, knowledgebox, resources): assert ask_response.retrieval_results.best_matches[0] == prev_best_matches[0] +@pytest.mark.asyncio +@patch("nucliadb.search.search.graph_strategy.rank_relations") +async def test_ask_graph_strategy(mocker, nucliadb_reader: AsyncClient, knowledgebox, graph_resource): + # Mock the rank_relations function to 
return the same relations (no ranking) + # This function is already unit tested and requires predict + mocker.side_effect = lambda *args, **kwargs: args[0] + + data = { + "query": "Which actors have been in movies directed by Christopher Nolan?", + "rag_strategies": [ + { + "name": "graph", + "hops": 2, + "top_k": 5, + "agentic_graph_only": False, + } + ], + "debug": True, + } + headers = {"X-Synchronous": "True"} + + url = f"/kb/{knowledgebox}/ask" + + async def assert_ask(d, expected): + resp = await nucliadb_reader.post( + url, + json=d, + headers=headers, + ) + assert resp.status_code == 200, resp.text + ask_response = SyncAskResponse.model_validate_json(resp.content) + assert ask_response.status == "success" + paragraphs = ask_response.retrieval_results.resources[graph_resource].fields + paragraph_texts = { + p_id: paragraph.text + for p_id, field in paragraphs.items() + for paragraph in field.paragraphs.values() + } + assert paragraph_texts == expected + # We expect a ranking for each hop + assert mocker.call_count == 2 + mocker.reset_mock() + + expected = { + "/t/inception3": "Joseph Gordon-Levitt starred in Inception.", + "/t/inception2": "Leonardo DiCaprio starred in Inception.", + "/t/inception1": "Christopher Nolan directed Inception.", + } + await assert_ask(data, expected) + + data["query"] = "In which movie has DiCaprio starred? And Joseph Gordon-Levitt?" + expected = { + "/t/inception1": "Christopher Nolan directed Inception.", + "/t/inception3": "Joseph Gordon-Levitt starred in Inception.", + "/t/inception2": "Leonardo DiCaprio starred in Inception.", + "/t/leo": "Leonardo DiCaprio is a great actor. DiCaprio started acting in 1989.", + } + await assert_ask(data, expected) + + # Now with agentic graph only + data["rag_strategies"][0]["agentic_graph_only"] = True # type: ignore + expected = { + "/t/inception2": "Leonardo DiCaprio starred in Inception.", + "/t/leo": "Leonardo DiCaprio is a great actor. DiCaprio started acting in 1989.", + } + await assert_ask(data, expected) + + async def test_ask_rag_strategy_prequeries(nucliadb_reader: AsyncClient, knowledgebox, resources): resp = await nucliadb_reader.post( f"/kb/{knowledgebox}/ask", diff --git a/nucliadb/tests/search/unit/test_graph_strategy.py b/nucliadb/tests/search/unit/search/test_graph_strategy.py similarity index 100% rename from nucliadb/tests/search/unit/test_graph_strategy.py rename to nucliadb/tests/search/unit/search/test_graph_strategy.py diff --git a/nucliadb_models/src/nucliadb_models/search.py b/nucliadb_models/src/nucliadb_models/search.py index 03a8664f00..c7bab14823 100644 --- a/nucliadb_models/src/nucliadb_models/search.py +++ b/nucliadb_models/src/nucliadb_models/search.py @@ -1223,7 +1223,7 @@ class GraphStrategy(RagStrategy): name: Literal["graph"] = "graph" hops: int = Field( - default=1, + default=3, title="Number of hops", description="""Number of hops to take when exploring the graph for relevant context. For example, @@ -1234,11 +1234,16 @@ class GraphStrategy(RagStrategy): ge=1, ) top_k: int = Field( - default=20, + default=25, title="Top k", - description="Number of relationships to keep after each hop. This number correlates to more paragraphs being sent as context.", + description="Number of relationships to keep after each hop after ranking them by relevance to the query. 
This number correlates to more paragraphs being sent as context.", ge=1, ) + agentic_graph_only: bool = Field( + default=False, + title="Use only the graph extracted by an agent.", + description="If set to true, only entities extracted from a graph extraction agent are considered for context expansion.", + ) class TableImageStrategy(ImageRagStrategy): From 667db654455fac383851dc01b4bc93b01bc792f1 Mon Sep 17 00:00:00 2001 From: Carles Onielfa Date: Tue, 14 Jan 2025 11:14:31 +0100 Subject: [PATCH 11/17] Remove TODOs --- .../nucliadb/search/search/graph_strategy.py | 20 +++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/nucliadb/src/nucliadb/search/search/graph_strategy.py b/nucliadb/src/nucliadb/search/search/graph_strategy.py index 5c907fa722..f0fa5dfe7f 100644 --- a/nucliadb/src/nucliadb/search/search/graph_strategy.py +++ b/nucliadb/src/nucliadb/search/search/graph_strategy.py @@ -366,7 +366,6 @@ async def get_graph_results( relations.entities.update(new_subgraphs) # Rank the relevance of the relations - # TODO: Add upper bound to the number of entities to explore for safety with metrics.time("graph_strat_rank_relations"): try: relations = await rank_relations( @@ -446,8 +445,6 @@ async def fuzzy_search_entities( node_query(kbid, Method.SUGGEST, request, target_shard_replicas=target_shard_replicas) ) - # Gather - # TODO: What do I do with `incomplete_results`? try: results_raw = await asyncio.gather(*tasks) return await merge_suggest_results( @@ -473,13 +470,19 @@ async def rank_relations( max_rels_to_eval: int = 300, ) -> Relations: # Store the index for keeping track after scoring - # XXX: Here we set a hard limit on the number of relations to evaluate for safety and performance - # In the future we could to several iterations of scoring flat_rels: list[tuple[str, int, DirectionalRelation]] = [ (ent, idx, rel) for (ent, rels) in relations.entities.items() for (idx, rel) in enumerate(rels.related_to) - ][:max_rels_to_eval] + ] + + # XXX: Here we set a hard limit on the number of relations to evaluate for safety and performance + # In the future we could do several iterations of scoring + if len(flat_rels) > max_rels_to_eval: + logger.warning( + f"Too many relations to evaluate ({len(flat_rels)}), truncating to {max_rels_to_eval}" + ) + flat_rels = flat_rels[:max_rels_to_eval] triplets: list[dict[str, str]] = [ { "head_entity": ent, @@ -523,7 +526,7 @@ async def rank_relations( max_tokens=4096, generative_model=generative_model, ) - # TODO: Enclose this in a try-except block + ident, model, answer_stream = await predict.chat_query_ndjson(kbid, chat_model) response_json = None status = None @@ -538,11 +541,8 @@ async def rank_relations( elif isinstance(item, MetaGenerativeResponse): _ = item else: - # TODO: Improve for logging raise ValueError(f"Unknown generative chunk type: {item}") - # TODO: Report tokens using meta?
- if response_json is None or status is None or status.code != "0": raise ValueError("No JSON response found") From 332da3c3ea3a0caa53d2d90fa9bf5d0fdcb5ddfc Mon Sep 17 00:00:00 2001 From: Carles Onielfa Date: Tue, 14 Jan 2025 11:20:58 +0100 Subject: [PATCH 12/17] More coverage --- nucliadb/tests/nucliadb/integration/test_ask.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/nucliadb/tests/nucliadb/integration/test_ask.py b/nucliadb/tests/nucliadb/integration/test_ask.py index ca6353ef72..857b859666 100644 --- a/nucliadb/tests/nucliadb/integration/test_ask.py +++ b/nucliadb/tests/nucliadb/integration/test_ask.py @@ -156,7 +156,8 @@ async def graph_resource(nucliadb_writer, nucliadb_grpc, knowledgebox): to=nodes["inception"], relation_label="directed", metadata=RelationMetadata( - paragraph_id=rid + "/t/inception1/0-37", + # Set this field type as its int enum value, since this is how legacy relations reported paragraph_id + paragraph_id=rid + "/4/inception1/0-37", data_augmentation_task_id="my_graph_task_id", ), ), From 1da11f8b154db34f1f1260e7b0aae78908e24bf5 Mon Sep 17 00:00:00 2001 From: Carles Onielfa Date: Wed, 15 Jan 2025 08:20:32 +0100 Subject: [PATCH 13/17] Add unit test for int field type ids --- .../tests/nucliadb/unit/common/test_ids.py | 20 ++++++++++++++++++- 1 file changed, 19 insertions(+), 1 deletion(-) diff --git a/nucliadb/tests/nucliadb/unit/common/test_ids.py b/nucliadb/tests/nucliadb/unit/common/test_ids.py index fe69aa89d2..2e58fc5870 100644 --- a/nucliadb/tests/nucliadb/unit/common/test_ids.py +++ b/nucliadb/tests/nucliadb/unit/common/test_ids.py @@ -22,7 +22,13 @@ import pytest -from nucliadb.common.ids import FieldId, ParagraphId, VectorId, extract_data_augmentation_id +from nucliadb.common.ids import ( + FIELD_TYPE_PB_TO_STR, + FieldId, + ParagraphId, + VectorId, + extract_data_augmentation_id, +) from nucliadb_protos.resources_pb2 import FieldType @@ -58,6 +64,18 @@ def test_field_ids(): assert field_id.full() == "rid/u/field_id/subfield_id" +def test_field_ids_int_field_type(): + # Test that we can use integers as field types + for value in FieldType.values(): + field_id = FieldId.from_string(f"rid/{value}/field_id/subfield_id") + assert field_id.rid == "rid" + assert field_id.type == FIELD_TYPE_PB_TO_STR[value] + assert field_id.key == "field_id" + assert field_id.subfield_id == "subfield_id" + assert field_id.full() == f"rid/{FIELD_TYPE_PB_TO_STR[value]}/field_id/subfield_id" + assert field_id.pb_type == value + + def test_paragraph_ids(): invalids = [ "foobar", From a890893af889af67b17358f867a15a6cd7ac449d Mon Sep 17 00:00:00 2001 From: Carles Onielfa Date: Wed, 15 Jan 2025 08:26:35 +0100 Subject: [PATCH 14/17] Fixes from feedback --- nucliadb/src/nucliadb/search/api/v1/suggest.py | 2 -- nucliadb/src/nucliadb/search/search/graph_strategy.py | 2 -- nucliadb/src/nucliadb/search/search/merge.py | 2 -- nucliadb_models/src/nucliadb_models/search.py | 2 ++ 4 files changed, 2 insertions(+), 6 deletions(-) diff --git a/nucliadb/src/nucliadb/search/api/v1/suggest.py b/nucliadb/src/nucliadb/search/api/v1/suggest.py index efe1834e00..02b147901a 100644 --- a/nucliadb/src/nucliadb/search/api/v1/suggest.py +++ b/nucliadb/src/nucliadb/search/api/v1/suggest.py @@ -151,8 +151,6 @@ async def suggest( search_results = await merge_suggest_results( results, kbid=kbid, - show=show, - field_type_filter=field_type_filter, highlight=highlight, ) diff --git a/nucliadb/src/nucliadb/search/search/graph_strategy.py b/nucliadb/src/nucliadb/search/search/graph_strategy.py
index f0fa5dfe7f..e62d01a104 100644 --- a/nucliadb/src/nucliadb/search/search/graph_strategy.py +++ b/nucliadb/src/nucliadb/search/search/graph_strategy.py @@ -450,8 +450,6 @@ async def fuzzy_search_entities( return await merge_suggest_results( [item for r in results_raw for item in r[0]], kbid=kbid, - show=show, - field_type_filter=field_type_filter, ) except Exception as e: capture_exception(e) diff --git a/nucliadb/src/nucliadb/search/search/merge.py b/nucliadb/src/nucliadb/search/search/merge.py index fa9f26f3ab..ac62e930bc 100644 --- a/nucliadb/src/nucliadb/search/search/merge.py +++ b/nucliadb/src/nucliadb/search/search/merge.py @@ -603,8 +603,6 @@ async def merge_suggest_entities_results( async def merge_suggest_results( suggest_responses: list[SuggestResponse], kbid: str, - show: list[ResourceProperties], - field_type_filter: list[FieldTypeName], highlight: bool = False, ) -> KnowledgeboxSuggestResults: api_results = KnowledgeboxSuggestResults() diff --git a/nucliadb_models/src/nucliadb_models/search.py b/nucliadb_models/src/nucliadb_models/search.py index c7bab14823..71249c33c1 100644 --- a/nucliadb_models/src/nucliadb_models/search.py +++ b/nucliadb_models/src/nucliadb_models/search.py @@ -1232,12 +1232,14 @@ class GraphStrategy(RagStrategy): And so on. Bigger values will discover more intricate relationships but will also take more time to compute.""", ge=1, + le=10, ) top_k: int = Field( default=25, title="Top k", description="Number of relationships to keep after each hop after ranking them by relevance to the query. This number correlates to more paragraphs being sent as context.", ge=1, + le=120, ) agentic_graph_only: bool = Field( default=False, From bab3fc4e8080ce9799f43603d6c5354344a66a55 Mon Sep 17 00:00:00 2001 From: Carles Onielfa Date: Wed, 15 Jan 2025 08:28:29 +0100 Subject: [PATCH 15/17] Refactor --- nucliadb/src/nucliadb/search/search/graph_strategy.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/nucliadb/src/nucliadb/search/search/graph_strategy.py b/nucliadb/src/nucliadb/search/search/graph_strategy.py index e62d01a104..152c402682 100644 --- a/nucliadb/src/nucliadb/search/search/graph_strategy.py +++ b/nucliadb/src/nucliadb/search/search/graph_strategy.py @@ -305,8 +305,6 @@ async def get_graph_results( suggest_result = await fuzzy_search_entities( kbid=kbid, query=query, - show=[], - field_type_filter=item.field_type_filter, range_creation_start=item.range_creation_start, range_creation_end=item.range_creation_end, range_modification_start=item.range_modification_start, @@ -411,8 +409,6 @@ async def get_graph_results( async def fuzzy_search_entities( kbid: str, query: str, - show: list[ResourceProperties], - field_type_filter: list[FieldTypeName], range_creation_start: Optional[datetime] = None, range_creation_end: Optional[datetime] = None, range_modification_start: Optional[datetime] = None, From c4c05c47d1cca8572de49de3d7abd6864aaf56c9 Mon Sep 17 00:00:00 2001 From: Carles Onielfa Date: Wed, 15 Jan 2025 08:40:34 +0100 Subject: [PATCH 16/17] Document merge relations --- nucliadb/src/nucliadb/search/search/merge.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/nucliadb/src/nucliadb/search/search/merge.py b/nucliadb/src/nucliadb/search/search/merge.py index ac62e930bc..9d8161cf9d 100644 --- a/nucliadb/src/nucliadb/search/search/merge.py +++ b/nucliadb/src/nucliadb/search/search/merge.py @@ -463,6 +463,18 @@ def _merge_relations_results( only_with_metadata: bool, only_agentic: bool, ) -> Relations: + """ + Merge relation search 
responses into a single Relations object while applying filters. + + Args: + relations_responses: List of relation search responses + query: EntitiesSubgraphRequest object + only_with_metadata: If True, only include relations with metadata. This metadata includes paragraph_id and entity positions among other things. + only_agentic: If True, only include relations extracted by a Graph Extraction Agent. + + Returns: + Relations + """ relations = Relations(entities={}) for entry_point in query.entry_points: @@ -475,6 +487,8 @@ def _merge_relations_results( relation_type = RelationTypePbMap[relation.relation] relation_label = relation.relation_label metadata = relation.metadata if relation.HasField("metadata") else None + # If only_with_metadata is True, we check that metadata for the relation is not None + # If only_agentic is True, we check that metadata for the relation is not None and that it has a data_augmentation_task_id if (not only_with_metadata or metadata) and ( not only_agentic or (metadata and metadata.data_augmentation_task_id) ): From 2719ba64254d7dec88e524c65b6208a211590555 Mon Sep 17 00:00:00 2001 From: Carles Onielfa Date: Wed, 15 Jan 2025 14:51:23 +0100 Subject: [PATCH 17/17] Addressed comments --- nucliadb/src/nucliadb/search/search/find_merge.py | 2 ++ nucliadb/src/nucliadb/search/search/graph_strategy.py | 2 +- nucliadb/src/nucliadb/search/search/merge.py | 1 + nucliadb_models/src/nucliadb_models/search.py | 6 +++--- 4 files changed, 7 insertions(+), 4 deletions(-) diff --git a/nucliadb/src/nucliadb/search/search/find_merge.py b/nucliadb/src/nucliadb/search/search/find_merge.py index 8e709ee50f..911058edce 100644 --- a/nucliadb/src/nucliadb/search/search/find_merge.py +++ b/nucliadb/src/nucliadb/search/search/find_merge.py @@ -231,6 +231,8 @@ def paragraph_id_to_text_block_match(paragraph_id: str) -> TextBlockMatch: Given a paragraph_id, return a TextBlockMatch with the bare minimum fields This is required by the Graph Strategy to get text blocks from the relevant paragraphs """ + # XXX: this is a workaround for the fact that we always assume retrieval means keyword/semantic search and + # that the hydration and find response building code works with TextBlockMatch parsed_paragraph_id = ParagraphId.from_string(paragraph_id) return TextBlockMatch( paragraph_id=parsed_paragraph_id, diff --git a/nucliadb/src/nucliadb/search/search/graph_strategy.py b/nucliadb/src/nucliadb/search/search/graph_strategy.py index 152c402682..0519ae0a8c 100644 --- a/nucliadb/src/nucliadb/search/search/graph_strategy.py +++ b/nucliadb/src/nucliadb/search/search/graph_strategy.py @@ -432,7 +432,7 @@ async def fuzzy_search_entities( tasks = [] # XXX: Splitting by words is not ideal, in the future, modify suggest to better handle this for word in query.split(): - if len(word) <= 3: + if len(word) < 3: continue request = nodereader_pb2.SuggestRequest() request.CopyFrom(base_request) diff --git a/nucliadb/src/nucliadb/search/search/merge.py b/nucliadb/src/nucliadb/search/search/merge.py index 9d8161cf9d..b95c6fa5c7 100644 --- a/nucliadb/src/nucliadb/search/search/merge.py +++ b/nucliadb/src/nucliadb/search/search/merge.py @@ -489,6 +489,7 @@ def _merge_relations_results( metadata = relation.metadata if relation.HasField("metadata") else None # If only_with_metadata is True, we check that metadata for the relation is not None # If only_agentic is True, we check that metadata for the relation is not None and that it has a data_augmentation_task_id + # TODO: This is suboptimal, we should be able to filter this in
the query to the index. if (not only_with_metadata or metadata) and ( not only_agentic or (metadata and metadata.data_augmentation_task_id) ): diff --git a/nucliadb_models/src/nucliadb_models/search.py b/nucliadb_models/src/nucliadb_models/search.py index 71249c33c1..97ff5142ab 100644 --- a/nucliadb_models/src/nucliadb_models/search.py +++ b/nucliadb_models/src/nucliadb_models/search.py @@ -962,7 +962,7 @@ class ChatModel(BaseModel): format_prompt: bool = Field( default=True, - description="If set to false, the prompt will be used as is, without any formatting for query or context", + description="If set to false, the prompt given as `user_prompt` will be used as is, without any formatting for question or context. If set to true, the prompt must contain the placeholders {question} and {context} to be replaced by the question and context respectively", # noqa: E501 ) @@ -1244,7 +1244,7 @@ class GraphStrategy(RagStrategy): agentic_graph_only: bool = Field( default=False, title="Use only the graph extracted by an agent.", - description="If set to true, only entities extracted from a graph extraction agent are considered for context expansion.", + description="If set to true, only relationships extracted from a graph extraction agent are considered for context expansion.", ) @@ -1424,7 +1424,7 @@ class AskRequest(AuditMetadataBase): - `neighbouring_paragraphs` will add the sorrounding paragraphs to the context for each matching paragraph. - `metadata_extension` will add the metadata of the matching paragraphs or its resources to the context. - `prequeries` allows to run multiple retrieval queries before the main query and add the results to the context. The results of specific queries can be boosted by the specifying weights. -- `graph` will retrieve context pieces by exploring the Knowledge Graph, starting from the entities present in the query. +- `graph` will retrieve context pieces by exploring the Knowledge Graph, starting from the entities present in the query. This strategy is not compatible with the `prequeries` strategy. If empty, the default strategy is used, which simply adds the text of the matching paragraphs to the context. """
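
For reference, a minimal sketch of how a client could exercise the graph RAG strategy configured by this series. The endpoint path, the `X-Synchronous` header and the `rag_strategies` payload mirror the integration test above; the base URL, the roles header and the knowledge box id are illustrative assumptions that depend on the deployment, not something defined by these patches.

import asyncio

import httpx


async def ask_with_graph_strategy(kbid: str, query: str) -> dict:
    # Payload fields come from the GraphStrategy model in nucliadb_models/search.py:
    # after this series, hops defaults to 3 (1 <= hops <= 10) and top_k to 25
    # (1 <= top_k <= 120, relations kept per hop after ranking them by relevance).
    payload = {
        "query": query,
        "rag_strategies": [
            {
                "name": "graph",
                "hops": 3,
                "top_k": 25,
                # When True, only relationships extracted by a graph extraction
                # agent (i.e. with data_augmentation_task_id set in their
                # RelationMetadata) are considered for context expansion.
                "agentic_graph_only": False,
            }
        ],
    }
    # Base URL and roles header are assumptions for a local standalone setup.
    async with httpx.AsyncClient(base_url="http://localhost:8080/api/v1") as client:
        resp = await client.post(
            f"/kb/{kbid}/ask",
            json=payload,
            headers={"X-Synchronous": "True", "X-NUCLIADB-ROLES": "READER"},
        )
        resp.raise_for_status()
        return resp.json()


if __name__ == "__main__":
    # "my-kb-id" is a placeholder; use a real knowledge box id.
    print(asyncio.run(ask_with_graph_strategy("my-kb-id", "Who directed Inception?")))

As in the integration test above, the `X-Synchronous` header makes `/ask` return a single JSON body that parses as `SyncAskResponse`; without it, the endpoint streams the answer instead.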