diff --git a/backend/director/agents/index.py b/backend/director/agents/index.py index 28e0b9e..a385fbe 100644 --- a/backend/director/agents/index.py +++ b/backend/director/agents/index.py @@ -51,6 +51,12 @@ "default": "shot", "description": "Method to use for scene detection and frame extraction", }, + "model_name": { + "type": "string", + "description": "The name of the model to use for scene detection and frame extraction", + "default": "gpt4-o", + "enum": ["gemini-1.5-flash", "gemini-1.5-pro", "gpt4-o"], + }, "shot_based_config": { "type": "object", "description": "Configuration for shot-based scene detection and frame extraction, This is a required parameter for shot_based indexing", @@ -246,6 +252,9 @@ def run( elif index_type == "scene": scene_index_type = scene_index_config["type"] + scene_index_model_name = scene_index_config.get( + "model_name", "gpt4-o" + ) scene_index_config = scene_index_config[ scene_index_type + "_based_config" ] @@ -254,6 +263,7 @@ def run( extraction_type=scene_index_type, extraction_config=scene_index_config, prompt=scene_index_prompt, + model_name=scene_index_model_name, ) self.videodb_tool.get_scene_index( video_id=video_id, scene_id=scene_index_id diff --git a/backend/director/tools/videodb_tool.py b/backend/director/tools/videodb_tool.py index 68d4b18..f96722f 100644 --- a/backend/director/tools/videodb_tool.py +++ b/backend/director/tools/videodb_tool.py @@ -150,6 +150,7 @@ def index_scene( video_id: str, extraction_type=SceneExtractionType.shot_based, extraction_config={}, + model_name=None, prompt=None, ): video = self.collection.get_video(video_id) @@ -157,6 +158,7 @@ def index_scene( extraction_type=extraction_type, extraction_config=extraction_config, prompt=prompt, + model_name=model_name, ) def list_scene_index(self, video_id: str):