Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add graph strategy #2772

Open
wants to merge 18 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 18 additions & 4 deletions nucliadb/src/nucliadb/common/ids.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,13 +111,11 @@ def from_string(cls, value: str) -> "FieldId":
parts = value.split("/")
if len(parts) == 3:
rid, _type, key = parts
if _type not in FIELD_TYPE_STR_TO_PB:
raise ValueError(f"Invalid FieldId: {value}")
_type = cls.parse_field_type(_type)
return cls(rid=rid, type=_type, key=key)
elif len(parts) == 4:
rid, _type, key, subfield_id = parts
if _type not in FIELD_TYPE_STR_TO_PB:
raise ValueError(f"Invalid FieldId: {value}")
_type = cls.parse_field_type(_type)
return cls(
rid=rid,
type=_type,
Expand All @@ -127,6 +125,22 @@ def from_string(cls, value: str) -> "FieldId":
else:
raise ValueError(f"Invalid FieldId: {value}")

@classmethod
def parse_field_type(cls, _type: str) -> str:
    """Normalize a field-type token to its canonical string form.

    Accepts either a canonical field-type string (a key of
    ``FIELD_TYPE_STR_TO_PB``) or the stringified integer value of the
    ``FieldType`` protobuf enum — the latter is how legacy processor
    relations reported the paragraph_id.

    Returns:
        The canonical field-type string.

    Raises:
        ValueError: if the token is neither a known field-type string nor
            a valid ``FieldType`` integer value.
    """
    if _type in FIELD_TYPE_STR_TO_PB:
        return _type
    # XXX: Legacy fallback — try to interpret the token as the integer
    # value of the FieldType enum.
    try:
        type_pb = FieldType.ValueType(int(_type))
    except ValueError as exc:
        # Chain the original parse error for a clearer traceback.
        raise ValueError(f"Invalid FieldId: {_type}") from exc
    if type_pb in FIELD_TYPE_PB_TO_STR:
        return FIELD_TYPE_PB_TO_STR[type_pb]
    raise ValueError(f"Invalid FieldId: {_type}")

carlesonielfa marked this conversation as resolved.
Show resolved Hide resolved

@dataclass
class ParagraphId:
Expand Down
2 changes: 0 additions & 2 deletions nucliadb/src/nucliadb/search/api/v1/suggest.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,8 +151,6 @@ async def suggest(
search_results = await merge_suggest_results(
results,
kbid=kbid,
show=show,
field_type_filter=field_type_filter,
highlight=highlight,
)

Expand Down
45 changes: 35 additions & 10 deletions nucliadb/src/nucliadb/search/search/chat/ask.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@
IncompleteFindResultsError,
InvalidQueryError,
)
from nucliadb.search.search.graph_strategy import get_graph_results
from nucliadb.search.search.metrics import RAGMetrics
from nucliadb.search.search.query import QueryParser
from nucliadb.search.utilities import get_predict
Expand All @@ -75,6 +76,7 @@
ErrorAskResponseItem,
FindParagraph,
FindRequest,
GraphStrategy,
JSONAskResponseItem,
KnowledgeboxFindResults,
MetadataAskResponseItem,
Expand Down Expand Up @@ -629,6 +631,13 @@ def parse_prequeries(ask_request: AskRequest) -> Optional[PreQueriesStrategy]:
return None


def parse_graph_strategy(ask_request: AskRequest) -> Optional[GraphStrategy]:
    """Return the graph RAG strategy configured on the request, if any.

    Scans ``ask_request.rag_strategies`` and returns the first entry whose
    name is ``RagStrategyName.GRAPH``; returns ``None`` when the request
    does not configure a graph strategy.
    """
    found = next(
        (s for s in ask_request.rag_strategies if s.name == RagStrategyName.GRAPH),
        None,
    )
    if found is None:
        return None
    return cast(GraphStrategy, found)


async def retrieval_step(
kbid: str,
main_query: str,
Expand Down Expand Up @@ -675,17 +684,33 @@ async def retrieval_in_kb(
metrics: RAGMetrics,
) -> RetrievalResults:
prequeries = parse_prequeries(ask_request)
graph_strategy = parse_graph_strategy(ask_request)
with metrics.time("retrieval"):
main_results, prequeries_results, query_parser = await get_find_results(
kbid=kbid,
query=main_query,
item=ask_request,
ndb_client=client_type,
user=user_id,
origin=origin,
metrics=metrics,
prequeries_strategy=prequeries,
)
prequeries_results = None
if graph_strategy is not None:
main_results, query_parser = await get_graph_results(
kbid=kbid,
query=main_query,
item=ask_request,
ndb_client=client_type,
user=user_id,
origin=origin,
graph_strategy=graph_strategy,
metrics=metrics,
shards=ask_request.shards,
)
# TODO (oni): Fallback to normal retrieval if no graph results are found
else:
main_results, prequeries_results, query_parser = await get_find_results(
kbid=kbid,
query=main_query,
item=ask_request,
ndb_client=client_type,
user=user_id,
origin=origin,
metrics=metrics,
prequeries_strategy=prequeries,
)
if len(main_results.resources) == 0 and all(
len(prequery_result.resources) == 0 for (_, prequery_result) in prequeries_results or []
):
Expand Down
6 changes: 4 additions & 2 deletions nucliadb/src/nucliadb/search/search/chat/prompt.py
Original file line number Diff line number Diff line change
Expand Up @@ -1013,8 +1013,10 @@ async def _build_context(self, context: CappedPromptContext) -> None:
neighbouring_paragraphs = cast(NeighbouringParagraphsStrategy, strategy)
elif strategy.name == RagStrategyName.METADATA_EXTENSION:
metadata_extension = cast(MetadataExtensionStrategy, strategy)
elif strategy.name != RagStrategyName.PREQUERIES: # pragma: no cover
# Prequeries are not handled here
elif (
strategy.name != RagStrategyName.PREQUERIES and strategy.name != RagStrategyName.GRAPH
): # pragma: no cover
# Prequeries and graph are not handled here
logger.warning(
"Unknown rag strategy",
extra={"strategy": strategy.name, "kbid": self.kbid},
Expand Down
84 changes: 56 additions & 28 deletions nucliadb/src/nucliadb/search/search/chat/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
import asyncio
from typing import Optional
from typing import Iterable, Optional

from nucliadb.common.models_utils import to_proto
from nucliadb.search import logger
Expand Down Expand Up @@ -51,6 +51,7 @@
)
from nucliadb_protos import audit_pb2
from nucliadb_protos.nodereader_pb2 import RelationSearchResponse, SearchRequest, SearchResponse
from nucliadb_protos.utils_pb2 import RelationNode
from nucliadb_telemetry.errors import capture_exception
from nucliadb_utils.utilities import get_audit

Expand Down Expand Up @@ -145,15 +146,7 @@ async def get_find_results(
return main_results, prequeries_results, query_parser


async def run_main_query(
kbid: str,
query: str,
item: AskRequest,
ndb_client: NucliaDBClientType,
user: str,
origin: str,
metrics: RAGMetrics = RAGMetrics(),
) -> tuple[KnowledgeboxFindResults, QueryParser]:
def find_request_from_ask_request(item: AskRequest, query: str) -> FindRequest:
find_request = FindRequest()
find_request.resource_filters = item.resource_filters
find_request.features = []
Expand Down Expand Up @@ -189,7 +182,19 @@ async def run_main_query(
find_request.show_hidden = item.show_hidden

# this executes the model validators, that can tweak some fields
FindRequest.model_validate(find_request)
return FindRequest.model_validate(find_request)


async def run_main_query(
kbid: str,
query: str,
item: AskRequest,
ndb_client: NucliaDBClientType,
user: str,
origin: str,
metrics: RAGMetrics = RAGMetrics(),
) -> tuple[KnowledgeboxFindResults, QueryParser]:
find_request = find_request_from_ask_request(item, query)

find_results, incomplete, query_parser = await find(
kbid,
Expand All @@ -211,36 +216,59 @@ async def get_relations_results(
text_answer: str,
target_shard_replicas: Optional[list[str]],
timeout: Optional[float] = None,
only_with_metadata: bool = False,
only_agentic_relations: bool = False,
) -> Relations:
try:
predict = get_predict()
detected_entities = await predict.detect_entities(kbid, text_answer)
request = SearchRequest()
request.relation_subgraph.entry_points.extend(detected_entities)
request.relation_subgraph.depth = 1

results: list[SearchResponse]
(
results,
_,
_,
) = await node_query(
kbid,
Method.SEARCH,
request,

return await get_relations_results_from_entities(
kbid=kbid,
entities=detected_entities,
target_shard_replicas=target_shard_replicas,
timeout=timeout,
use_read_replica_nodes=True,
retry_on_primary=False,
only_with_metadata=only_with_metadata,
only_agentic_relations=only_agentic_relations,
)
relations_results: list[RelationSearchResponse] = [result.relation for result in results]
return await merge_relations_results(relations_results, request.relation_subgraph)
except Exception as exc:
capture_exception(exc)
logger.exception("Error getting relations results")
return Relations(entities={})


async def get_relations_results_from_entities(
    *,
    kbid: str,
    entities: Iterable[RelationNode],
    target_shard_replicas: Optional[list[str]],
    timeout: Optional[float] = None,
    only_with_metadata: bool = False,
    only_agentic_relations: bool = False,
) -> Relations:
    """Fetch and merge the relations subgraph around the given entities.

    Builds a depth-1 relation-subgraph ``SearchRequest`` with *entities* as
    entry points, fans it out to the shard nodes via ``node_query`` (read
    replicas allowed, no retry on primary), and merges the per-shard
    relation responses into a single ``Relations`` object.
    """
    search_request = SearchRequest()
    subgraph = search_request.relation_subgraph
    subgraph.entry_points.extend(entities)
    subgraph.depth = 1

    node_responses: list[SearchResponse]
    node_responses, _, _ = await node_query(
        kbid,
        Method.SEARCH,
        search_request,
        target_shard_replicas=target_shard_replicas,
        timeout=timeout,
        use_read_replica_nodes=True,
        retry_on_primary=False,
    )
    shard_relations: list[RelationSearchResponse] = [
        response.relation for response in node_responses
    ]
    return await merge_relations_results(
        shard_relations, subgraph, only_with_metadata, only_agentic_relations
    )


def maybe_audit_chat(
*,
kbid: str,
Expand Down
33 changes: 33 additions & 0 deletions nucliadb/src/nucliadb/search/search/find_merge.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,6 +226,39 @@ def merge_shards_relation_responses(
return merged


def paragraph_id_to_text_block_match(paragraph_id: str) -> TextBlockMatch:
    """Build a minimal TextBlockMatch for *paragraph_id*.

    Only the fields derivable from the paragraph id itself are populated.
    The Graph Strategy relies on these bare-bones matches to later fetch
    the text blocks of the relevant paragraphs.
    """
    pid = ParagraphId.from_string(paragraph_id)
    position = TextPosition(
        page_number=0,
        index=0,
        start=pid.paragraph_start,
        end=pid.paragraph_end,
        start_seconds=[],
        end_seconds=[],
    )
    return TextBlockMatch(
        paragraph_id=pid,
        score=0,
        score_type=SCORE_TYPE.BM25,
        order=0,  # NOTE: this will be filled later
        text="",  # NOTE: this will be filled later too
        position=position,
        field_labels=[],
        paragraph_labels=[],
        fuzzy_search=False,
        is_a_table=False,
        representation_file="",
        page_with_visual=False,
    )


def paragraph_id_to_text_block_matches(paragraph_ids: Iterable[str]) -> list[TextBlockMatch]:
    """Convert each paragraph id into a bare-minimum TextBlockMatch."""
    return list(map(paragraph_id_to_text_block_match, paragraph_ids))


def keyword_result_to_text_block_match(item: ParagraphResult) -> TextBlockMatch:
fuzzy_result = len(item.matches) > 0
return TextBlockMatch(
Expand Down
Loading
Loading