Skip to content

Commit

Permalink
Merge pull request #105 from video-db/add-support-for-fal
Browse files Browse the repository at this point in the history
feat(agent) : Comparison agent & fal.ai tool
  • Loading branch information
ankit-v2-3 authored Dec 20, 2024
2 parents d85932a + 42ffa85 commit 604df70
Show file tree
Hide file tree
Showing 7 changed files with 271 additions and 19 deletions.
4 changes: 4 additions & 0 deletions backend/.env.sample
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,11 @@ STABILITYAI_API_KEY=
KLING_AI_ACCESS_API_KEY=
KLING_AI_SECRET_API_KEY=

# FAL AI
FAL_KEY=

# Composio Agent
# https://composio.dev/tools
COMPOSIO_API_KEY=
COMPOSIO_APPS=["HACKERNEWS"]

146 changes: 146 additions & 0 deletions backend/director/agents/comparison.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,146 @@
import logging
import concurrent.futures
import queue

from director.agents.base import BaseAgent, AgentResponse, AgentStatus
from director.core.session import Session, VideosContent, VideoData, MsgStatus
from director.agents.video_generation import VideoGenerationAgent
from director.agents.video_generation import VIDEO_GENERATION_AGENT_PARAMETERS

logger = logging.getLogger(__name__)

# Properties of ONE comparison entry: a human-readable "description" followed
# by every property the video_generation agent itself accepts.
_COMPARISON_RUN_PROPERTIES = {
    "description": {
        "type": "string",
        "description": "Description of the video generation run, Mention the configuration picked for the run. Keep the engine name and parameter config in a separate line. Keep it short. Here's an example: 'Tokyo Sunset - Luma - Prompt: 'An aerial shot of a quiet sunset at Tokyo', Duration: 5s, Luma Dream Machine",
    },
}
_COMPARISON_RUN_PROPERTIES.update(VIDEO_GENERATION_AGENT_PARAMETERS["properties"])

# Required keys of one entry: "description" plus whatever video_generation requires.
_COMPARISON_RUN_REQUIRED = ["description"] + list(
    VIDEO_GENERATION_AGENT_PARAMETERS["required"]
)

# JSON-schema of a single video generation run inside the comparison list.
_COMPARISON_RUN_SCHEMA = {
    "type": "object",
    "properties": _COMPARISON_RUN_PROPERTIES,
    "required": _COMPARISON_RUN_REQUIRED,
    "description": "Parameters to use for each video generation run, each object in this is the parameters that would be required for each @video_generation run",
}

# JSON-schema describing the arguments accepted by ComparisonAgent.run.
COMPARISON_AGENT_PARAMETERS = {
    "type": "object",
    "properties": {
        "job_type": {
            "type": "string",
            "enum": ["video_generation_comparison"],
            "description": "Creates videos using MULTIPLE video generation models/engines. This agent should be used in two scenarios: 1) When request contains model names connected by words like 'and', '&', 'with', ',', 'versus', 'vs' (e.g. 'using Stability and Kling'), 2) When request mentions comparing/testing multiple models even if mentioned later in the prompt (e.g. 'Generate X. Compare results from Y, Z'). If the request suggests using more than one model in any way, use this agent rather than calling video_generation agent multiple times.",
        },
        "video_generation_comparison": {
            "type": "array",
            "items": _COMPARISON_RUN_SCHEMA,
            "description": "List of parameters to use for each video generation run, each object in this is the parameters that would be required for each @video_generation run",
        },
    },
    "required": ["job_type", "video_generation_comparison"],
}


class ComparisonAgent(BaseAgent):
    """Run several video generation jobs in parallel and show them together.

    Each entry of ``video_generation_comparison`` is handed to its own
    ``VideoGenerationAgent`` (in stealth mode, so the sub-agent does not push
    its own content to the output message). Results are merged into a single
    ``VideosContent`` block as the individual runs finish.
    """

    def __init__(self, session: Session, **kwargs):
        self.agent_name = "comparison"
        self.description = """Primary agent for video generation from prompts. Handles all video creation requests including single and multi-model generation. If multiple models or variations are mentioned, automatically parallelizes the work. For single model requests, delegates to specialized video generation subsystem. Keywords: generate video, create video, make video, text to video. """

        self.parameters = COMPARISON_AGENT_PARAMETERS
        super().__init__(session=session, **kwargs)

    def _run_video_generation(self, index, params):
        """Run one video generation job and return ``(index, response)``.

        This method never raises: any exception escaping the sub-agent is
        converted into an ERROR ``AgentResponse``. That guarantee matters
        because ``done_callback`` forwards the future's result to the
        notification queue; if the future carried an exception, the callback
        would raise, nothing would be queued, and ``run`` would block forever
        waiting for this index.

        :param int index: Position of this run in the comparison list
        :param dict params: Keyword arguments for ``VideoGenerationAgent.run``
        :return: Tuple of the run index and the agent response
        :rtype: tuple
        """
        try:
            video_gen_agent = VideoGenerationAgent(session=self.session)
            response = video_gen_agent.run(**params, stealth_mode=True)
        except Exception as e:
            logger.exception(f"Error in video generation run {index}: {e}")
            response = AgentResponse(status=AgentStatus.ERROR, message=str(e))
        return (index, response)

    def done_callback(self, fut):
        """Queue the ``(index, response)`` result of a finished future.

        :param fut: Completed future created from ``_run_video_generation``
        """
        result = fut.result()
        self.notification_queue.put(result)

    def _update_videos_content(self, videos_content, index, result):
        """Store the outcome of one run in ``videos_content`` and push an update.

        On success the placeholder at ``index`` is replaced with the generated
        video; on error it is replaced with an entry whose name is prefixed
        with ``[Error]`` and which has no stream URL.

        :param VideosContent videos_content: Content block being filled in
        :param int index: Position of the run inside ``videos_content.videos``
        :param AgentResponse result: Response returned by the sub-agent
        """
        if result.status == AgentStatus.SUCCESS:
            videos_content.videos[index] = result.data["video_content"].video
        elif result.status == AgentStatus.ERROR:
            videos_content.videos[index] = VideoData(
                name=f"[Error] {videos_content.videos[index].name}",
                stream_url="",
                id=None,
                collection_id=None,
            )
        self.output_message.push_update()

    def run(
        self, job_type: str, video_generation_comparison: list, *args, **kwargs
    ) -> AgentResponse:
        """
        Compare outputs from multiple runs of video generation
        :param str job_type: Type of comparison to perform
        :param list video_generation_comparison: Parameters for video gen runs
        :param args: Additional positional arguments
        :param kwargs: Additional keyword arguments
        :return: Response containing comparison results
        :rtype: AgentResponse
        """
        # Kept outside the try so the error handler can tell whether the
        # content block was already created and attached to the message.
        videos_content = None
        try:
            if job_type != "video_generation_comparison":
                raise Exception(f"Unsupported comparison type: {job_type}")

            videos_content = VideosContent(
                agent_name=self.agent_name,
                status=MsgStatus.progress,
                status_message="Generating videos (Usually takes 3-7 mins)",
                videos=[],
            )
            self.notification_queue = queue.Queue()

            # One placeholder per requested run, shown while generation is
            # in progress and replaced in place as results arrive.
            for params in video_generation_comparison:
                videos_content.videos.append(
                    VideoData(
                        name=params["text_to_video"]["name"],
                        stream_url="",
                    )
                )

            self.output_message.content.append(videos_content)
            self.output_message.push_update()

            # Use ThreadPoolExecutor to run video generations in parallel
            with concurrent.futures.ThreadPoolExecutor() as executor:
                for index, params in enumerate(video_generation_comparison):
                    future = executor.submit(
                        self._run_video_generation, index, params
                    )
                    future.add_done_callback(self.done_callback)

                # Exactly one queue item arrives per submitted run (see
                # _run_video_generation), so this loop always terminates.
                for _ in range(len(video_generation_comparison)):
                    index, response = self.notification_queue.get()
                    self._update_videos_content(videos_content, index, response)

            # All runs are accounted for: flip the block to its final state
            # once, instead of once per future.
            videos_content.status = MsgStatus.success
            videos_content.status_message = "Here are your generated videos"
            self.output_message.push_update()

            return AgentResponse(
                status=AgentStatus.SUCCESS,
                message="Video generation comparison complete",
                data={"videos": videos_content},
            )

        except Exception as e:
            logger.exception(f"Error in {self.agent_name} agent: {e}")
            if videos_content is not None:
                # Do not leave the already-attached block stuck in "progress".
                videos_content.status = MsgStatus.error
                videos_content.status_message = "Failed to generate videos"
                self.output_message.push_update()
            return AgentResponse(status=AgentStatus.ERROR, message=str(e))
74 changes: 57 additions & 17 deletions backend/director/agents/video_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,15 @@
PARAMS_CONFIG as STABILITYAI_PARAMS_CONFIG,
)
from director.tools.kling import KlingAITool, PARAMS_CONFIG as KLING_PARAMS_CONFIG

from director.tools.fal_video import (
FalVideoGenerationTool,
PARAMS_CONFIG as FAL_VIDEO_GEN_PARAMS_CONFIG,
)
from director.constants import DOWNLOADS_PATH

logger = logging.getLogger(__name__)

SUPPORTED_ENGINES = ["stabilityai", "kling"]
SUPPORTED_ENGINES = ["stabilityai", "kling", "fal"]

VIDEO_GENERATION_AGENT_PARAMETERS = {
"type": "object",
Expand All @@ -28,9 +31,9 @@
},
"engine": {
"type": "string",
"description": "The video generation engine to use",
"default": "stabilityai",
"enum": ["stabilityai", "kling"],
"description": "The video generation engine to use. Use Fal by default. If the query includes any of the following: 'minimax-video, mochi-v1, hunyuan-video, luma-dream-machine, cogvideox-5b, ltx-video, fast-svd, fast-svd-lcm, t2v-turbo, kling video v 1.0, kling video v1.5 pro, fast-animatediff, fast-animatediff turbo, and animatediff-sparsectrl-lcm'- always use Fal. In case user specifies any other engine, use the supported engines like Stability.",
"default": "fal",
"enum": ["fal", "stabilityai"],
},
"job_type": {
"type": "string",
Expand All @@ -48,6 +51,10 @@
"type": "string",
"description": "The text prompt to generate the video",
},
"name": {
"type": "string",
"description": "Description of the video generation run in two lines. Keep the engine name and parameter configuration of the engine in separate lines. Keep it short, but show the prompt in full. Here's an example: [Tokyo Sunset - Luma - Prompt: 'An aerial shot of a quiet sunset at Tokyo', Duration: 5s, Luma Dream Machine]",
},
"duration": {
"type": "number",
"description": "The duration of the video in seconds",
Expand All @@ -63,8 +70,13 @@
"properties": KLING_PARAMS_CONFIG["text_to_video"],
"description": "Config to use when kling engine is used",
},
"fal_config": {
"type": "object",
"properties": FAL_VIDEO_GEN_PARAMS_CONFIG["text_to_video"],
"description": "Config to use when fal engine is used",
},
},
"required": ["prompt"],
"required": ["prompt", "name"],
},
},
"required": ["job_type", "collection_id", "engine"],
Expand All @@ -74,7 +86,7 @@
class VideoGenerationAgent(BaseAgent):
def __init__(self, session: Session, **kwargs):
self.agent_name = "video_generation"
self.description = "Agent designed to generate videos from text prompts"
self.description = "Creates videos using ONE specific model/engine. Only use this agent when the request mentions exactly ONE model/engine, without any comparison words like 'compare', 'test', 'versus', 'vs' and no connecting words (and/&/,) between model names. If the request mentions wanting to compare models or try multiple engines, do not use this agent - use the comparison agent instead."
self.parameters = VIDEO_GENERATION_AGENT_PARAMETERS
super().__init__(session=session, **kwargs)

Expand All @@ -99,6 +111,7 @@ def run(
"""
try:
self.videodb_tool = VideoDBTool(collection_id=collection_id)
stealth_mode = kwargs.get("stealth_mode", False)

if engine not in SUPPORTED_ENGINES:
raise Exception(f"{engine} not supported")
Expand All @@ -119,6 +132,12 @@ def run(
secret_key=KLING_AI_SECRET_API_KEY,
)
config_key = "kling_config"
elif engine == "fal":
FAL_KEY = os.getenv("FAL_KEY")
if not FAL_KEY:
raise Exception("FAL API key not found")
video_gen_tool = FalVideoGenerationTool(api_key=FAL_KEY)
config_key = "fal_config"
else:
raise Exception(f"{engine} not supported")

Expand All @@ -131,18 +150,21 @@ def run(
status=MsgStatus.progress,
status_message="Processing...",
)
self.output_message.content.append(video_content)
if not stealth_mode:
self.output_message.content.append(video_content)

if job_type == "text_to_video":
prompt = text_to_video.get("prompt")
video_name = text_to_video.get("name")
duration = text_to_video.get("duration", 5)
config = text_to_video.get(config_key, {})
if prompt is None:
raise Exception("Prompt is required for video generation")
self.output_message.actions.append(
f"Generating video using <b>{engine}</b> for prompt <i>{prompt}</i>"
)
self.output_message.push_update()
if not stealth_mode:
self.output_message.push_update()
video_gen_tool.text_to_video(
prompt=prompt,
save_at=output_path,
Expand All @@ -155,32 +177,50 @@ def run(
self.output_message.actions.append(
f"Generated video saved at <i>{output_path}</i>"
)
self.output_message.push_update()
if not stealth_mode:
self.output_message.push_update()

# Upload to VideoDB
media = self.videodb_tool.upload(
output_path, source_type="file_path", media_type="video"
output_path,
source_type="file_path",
media_type="video",
name=video_name,
)
self.output_message.actions.append(
f"Uploaded generated video to VideoDB with Video ID {media['id']}"
)
stream_url = media["stream_url"]
video_content.video = VideoData(stream_url=stream_url)
id = media["id"]
collection_id = media["collection_id"]
name = media["name"]
video_content.video = VideoData(
stream_url=stream_url,
id=id,
collection_id=collection_id,
name=name,
)
video_content.status = MsgStatus.success
video_content.status_message = "Here is your generated video"
self.output_message.push_update()
self.output_message.publish()
if not stealth_mode:
self.output_message.push_update()
self.output_message.publish()

except Exception as e:
logger.exception(f"Error in {self.agent_name} agent: {e}")
video_content.status = MsgStatus.error
video_content.status_message = "Failed to generate video"
self.output_message.push_update()
self.output_message.publish()
if not stealth_mode:
self.output_message.push_update()
self.output_message.publish()
return AgentResponse(status=AgentStatus.ERROR, message=str(e))

return AgentResponse(
status=AgentStatus.SUCCESS,
message=f"Generated video ID {media['id']}",
data={"video_id": media["id"], "video_stream_url": stream_url},
data={
"video_id": media["id"],
"video_stream_url": stream_url,
"video_content": video_content,
},
)
13 changes: 11 additions & 2 deletions backend/director/core/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ class ContentType(str, Enum):

text = "text"
video = "video"
videos = "videos"
image = "image"
search_results = "search_results"

Expand Down Expand Up @@ -69,7 +70,8 @@ class TextContent(BaseContent):
class VideoData(BaseModel):
"""Video data model class for video content."""

stream_url: str
stream_url: Optional[str] = None
external_url: Optional[str] = None
player_url: Optional[str] = None
id: Optional[str] = None
collection_id: Optional[str] = None
Expand All @@ -86,6 +88,13 @@ class VideoContent(BaseContent):
type: ContentType = ContentType.video


class VideosContent(BaseContent):
    """Content block that carries a list of videos in one message.

    Used when a single message has to present several videos at once
    (for example the comparison agent's side-by-side results).
    """

    # Videos carried by this block; None when no list has been supplied.
    videos: Optional[List[VideoData]] = None
    # Content-type tag of this block, fixed to ContentType.videos.
    type: ContentType = ContentType.videos


class ImageData(BaseModel):
"""Image data model class for image content."""

Expand Down Expand Up @@ -142,7 +151,7 @@ class BaseMessage(BaseModel):
actions: List[str] = []
agents: List[str] = []
content: List[
Union[dict, TextContent, ImageContent, VideoContent, SearchResultsContent]
Union[dict, TextContent, ImageContent, VideoContent, VideosContent, SearchResultsContent]
] = []
status: MsgStatus = MsgStatus.success
msg_id: str = Field(
Expand Down
2 changes: 2 additions & 0 deletions backend/director/handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from director.agents.text_to_movie import TextToMovieAgent
from director.agents.meme_maker import MemeMakerAgent
from director.agents.composio import ComposioAgent
from director.agents.comparison import ComparisonAgent


from director.core.session import Session, InputMessage, MsgStatus
Expand Down Expand Up @@ -60,6 +61,7 @@ def __init__(self, db, **kwargs):
TextToMovieAgent,
MemeMakerAgent,
ComposioAgent,
ComparisonAgent,
]

def add_videodb_state(self, session):
Expand Down
Loading

0 comments on commit 604df70

Please sign in to comment.