feat: add og serving
imotai committed Oct 27, 2023
1 parent f46fa18 commit 37071e2
Showing 11 changed files with 1,105 additions and 10 deletions.
16 changes: 12 additions & 4 deletions agent/src/og_agent/agent_api_server.py
@@ -19,11 +19,19 @@
from fastapi.param_functions import Header, Annotated
from dotenv import dotenv_values

logger = logging.getLogger(__name__)

# the agent config
# the api server config
config = dotenv_values(".env")

LOG_LEVEL = (
logging.DEBUG if config.get("log_level", "info") == "debug" else logging.INFO
)
logging.basicConfig(
level=LOG_LEVEL,
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
handlers=[logging.StreamHandler(sys.stdout)],
)
logger = logging.getLogger(__name__)

app = FastAPI()
# the agent endpoint
listen_addr = "%s:%s" % (
@@ -32,7 +40,6 @@
)
if config.get("rpc_host", "") == "0.0.0.0":
listen_addr = "127.0.0.1:%s" % config.get("rpc_port", "9528")
logger.info(f"connect the agent server at {listen_addr}")
agent_sdk = AgentProxySDK(listen_addr)


@@ -191,6 +198,7 @@ async def process_task(


async def run_server():
logger.info(f"connect the agent server at {listen_addr}")
port = int(config.get("rpc_port", "9528")) + 1
server_config = uvicorn.Config(
app, host=config.get("rpc_host", "127.0.0.1"), port=port
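A note on the endpoint fallback in the hunk above: the agent process may bind 0.0.0.0 to listen on all interfaces, but a client cannot dial 0.0.0.0, so the API server rewrites the connect address to loopback. A minimal sketch of this rule in isolation (resolve_agent_endpoint is a hypothetical helper for illustration, not part of the commit):

def resolve_agent_endpoint(config: dict) -> str:
    # 0.0.0.0 is a bind-all address, not a routable destination,
    # so clients connect via loopback instead.
    host = config.get("rpc_host", "127.0.0.1")
    port = config.get("rpc_port", "9528")
    if host == "0.0.0.0":
        host = "127.0.0.1"
    return f"{host}:{port}"

assert resolve_agent_endpoint({"rpc_host": "0.0.0.0"}) == "127.0.0.1:9528"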
172 changes: 171 additions & 1 deletion agent/src/og_agent/base_agent.py
@@ -52,9 +52,179 @@ class TypingState:


class BaseAgent:

def __init__(self, sdk):
self.kernel_sdk = sdk
self.model_name = ""

def _merge_delta_for_function_call(self, message, delta):
if len(message.keys()) == 0:
message.update(delta)
return
if "function_call" not in message:
message["function_call"] = delta["function_call"]
return
old_arguments = message["function_call"].get("arguments", "")
if delta["function_call"]["arguments"]:
message["function_call"]["arguments"] = (
old_arguments + delta["function_call"]["arguments"]
)
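    # Illustrative example (not part of the commit): successive streamed
    # deltas accumulate into a single function call, e.g.
    #   message = {}
    #   _merge_delta_for_function_call(message, {"function_call": {"name": "python", "arguments": "print("}})
    #   _merge_delta_for_function_call(message, {"function_call": {"arguments": "1)"}})
    #   # message["function_call"]["arguments"] is now "print(1)"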

def _merge_delta_for_content(self, message, delta):
if not delta:
return
content = message.get("content", "")
if delta.get("content"):
message["content"] = content + delta["content"]

def _get_function_call_argument_new_typing(self, message):
if message["function_call"]["name"] == "python":
return TypingState.CODE, "", message["function_call"].get("arguments", "")

arguments = message["function_call"].get("arguments", "")
state = TypingState.START
explanation_str = ""
code_str = ""
for token_state, token in tokenize(io.StringIO(arguments)):
            if token_state is None:
if state == TypingState.EXPLANATION and token[0] == 1:
explanation_str = token[1]
state = TypingState.START
if state == TypingState.CODE and token[0] == 1:
code_str = token[1]
state = TypingState.START
if token[1] == "explanation":
state = TypingState.EXPLANATION
if token[1] == "code":
state = TypingState.CODE
else:
# String
if token_state == 9 and state == TypingState.EXPLANATION:
explanation_str = "".join(token)
elif token_state == 9 and state == TypingState.CODE:
code_str = "".join(token)
return (state, explanation_str, code_str)
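    # Illustrative behavior (not part of the commit): the function-call
    # arguments stream in as partial JSON such as
    #   '{"explanation": "sum two numbers", "code": "print(1 +'
    # The incremental json tokenizer above lets the partially typed
    # "explanation" and "code" values be surfaced before the JSON document
    # is complete, so the UI can stream them as the model types.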

def _get_message_token_count(self, message):
response_token_count = 0
if "function_call" in message and message["function_call"]:
arguments = message["function_call"].get("arguments", "")
response_token_count += len(encoding.encode(arguments))
if "content" in message and message["content"]:
response_token_count += len(encoding.encode(message.get("content")))
return response_token_count
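    # Illustrative: with the module-level encoding object used above
    # (presumably a tiktoken encoding), len(encoding.encode("hello world"))
    # is the token count that string contributes to output_token_count.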

    async def extract_message_for_json(
        self, response_generator, messages, queue, rpc_context, task_context, task_opt
    ):
        # Body omitted in this diff.
        ...

    async def extract_message(
        self, response_generator, messages, queue, rpc_context, task_context, task_opt
    ):
        """
        Extract the messages from the response generator.
        """
input_token_count = 0
for message in messages:
if not message["content"]:
continue
input_token_count += len(encoding.encode(message["content"]))
task_context.input_token_count += input_token_count
message = {}
text_content = ""
code_content = ""
        context_output_token_count = task_context.output_token_count
        # Start the chunk timer before streaming begins; it is referenced
        # below when accumulating llm_response_duration.
        start_time = time.time()
        async for chunk in response_generator:
if rpc_context.done():
logger.debug("the client has cancelled the request")
break
if not chunk["choices"]:
continue
task_context.llm_name = chunk.get("model", "")
self.model_name = chunk.get("model", "")
delta = chunk["choices"][0]["delta"]
if "function_call" in delta:
self._merge_delta_for_function_call(message, delta)
response_token_count = self._get_message_token_count(message)
task_context.output_token_count = (
response_token_count + context_output_token_count
)
task_context.llm_response_duration += int(
(time.time() - start_time) * 1000
)
start_time = time.time()
(
state,
explanation_str,
code_str,
) = self._get_function_call_argument_new_typing(message)
logger.debug(
f"argument explanation:{explanation_str} code:{code_str} text_content:{text_content}"
)
if explanation_str and text_content != explanation_str:
typed_chars = explanation_str[len(text_content) :]
text_content = explanation_str
if task_opt.streaming and len(typed_chars) > 0:
task_response = TaskResponse(
state=task_context.to_context_state_proto(),
response_type=TaskResponse.OnModelTypeText,
typing_content=TypingContent(
content=typed_chars, language="text"
),
)
await queue.put(task_response)
if code_str and code_content != code_str:
typed_chars = code_str[len(code_content) :]
code_content = code_str
if task_opt.streaming and len(typed_chars) > 0:
typing_language = "text"
if delta["function_call"].get("name", "") in [
"execute_python_code",
"python",
]:
typing_language = "python"
elif (
delta["function_call"].get("name", "")
== "execute_bash_code"
):
typing_language = "bash"
await queue.put(
TaskResponse(
state=task_context.to_context_state_proto(),
response_type=TaskResponse.OnModelTypeCode,
typing_content=TypingContent(
content=typed_chars, language=typing_language
),
)
)
else:
self._merge_delta_for_content(message, delta)
task_context.llm_response_duration += int(
(time.time() - start_time) * 1000
)
start_time = time.time()
if message.get("content") != None:
response_token_count = self._get_message_token_count(message)
task_context.output_token_count = (
response_token_count + context_output_token_count
)
if task_opt.streaming and delta.get("content"):
await queue.put(
TaskResponse(
state=task_context.to_context_state_proto(),
response_type=TaskResponse.OnModelTypeText,
typing_content=TypingContent(
content=delta["content"], language="text"
),
)
)
        logger.info(
            f"called {self.model_name} with {task_context.input_token_count} input tokens and {task_context.output_token_count} output tokens"
        )
return message

async def call_function(self, code, context, task_context=None):
"""
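A hedged driver sketch for extract_message, assuming an OpenAI-style streaming generator; fake_stream and the stand-in arguments are illustrative only, and the real rpc_context, task_context, and task_opt objects come from the gRPC layer:

import asyncio

async def fake_stream():
    # Two content-only chunks in the OpenAI streaming shape.
    yield {"model": "gpt-4", "choices": [{"delta": {"content": "Hello"}}]}
    yield {"model": "gpt-4", "choices": [{"delta": {"content": " world"}}]}

async def drive(agent, rpc_context, task_context, task_opt):
    queue = asyncio.Queue()
    messages = [{"role": "user", "content": "say hello"}]
    final = await agent.extract_message(
        fake_stream(), messages, queue, rpc_context, task_context, task_opt
    )
    # final["content"] == "Hello world"; if task_opt.streaming is set,
    # the streamed TaskResponse items are waiting on the queue.
    return final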
1 change: 0 additions & 1 deletion agent/src/og_agent/openai_agent.py
@@ -68,7 +68,6 @@


class OpenaiAgent(BaseAgent):

def __init__(self, model, system_prompt, sdk, is_azure=True):
super().__init__(sdk)
self.model = model
2 changes: 1 addition & 1 deletion format.sh
@@ -1,4 +1,4 @@
#! /bin/sh
#
# format.sh
pyink agent kernel chat up sdk examples
pyink agent kernel chat up sdk examples serving
9 changes: 6 additions & 3 deletions memory/src/og_memory/memory.py
@@ -9,22 +9,25 @@
"""


# import the agent memory
from og_proto.memory_pb2 import AgentMemory
from og_proto.memory_pb2 import AgentMemory as AgentMemoryProto
from jinja2 import Environment
from jinja2.loaders import PackageLoader

env = Environment(loader=PackageLoader("og_memory", "template"))

context_tpl = env.get_template("agent.jinja")

def agent_memory_to_context(memory: AgentMemory):
def agent_memory_to_context(memory: AgentMemoryProto):
"""
Convert the agent memory to context
    :param memory: AgentMemoryProto
    :return: string context for the llm
"""
return context_tpl.render(prompt=memory.instruction, guides=memory.guide_memory)

class AgentMemory:
def __init__(self, path):
self.path = path


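A brief usage sketch for agent_memory_to_context; the field values below are illustrative, and only instruction and guide_memory are assumed, since those are the fields the render() call above uses:

from og_proto.memory_pb2 import AgentMemory as AgentMemoryProto
from og_memory.memory import agent_memory_to_context

# Render the LLM context from an instruction; any guide_memory entries
# are interpolated by template/agent.jinja as well.
memory = AgentMemoryProto(instruction="You are a coding assistant.")
print(agent_memory_to_context(memory))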
4 changes: 4 additions & 0 deletions requirements.txt
@@ -21,4 +21,8 @@ tiktoken
fastapi
uvicorn
pytest-mock
pydantic-settings
sse-starlette
starlette-context
llama-cpp-python

1 change: 1 addition & 0 deletions serving/README.md
@@ -0,0 +1 @@
the serving module for octogen
32 changes: 32 additions & 0 deletions serving/setup.py
@@ -0,0 +1,32 @@
# Copyright (C) 2023 dbpunk.com Author imotai <codego.me@gmail.com>
# SPDX-FileCopyrightText: 2023 imotai <jackwang@octogen.dev>
# SPDX-FileContributor: imotai
#
# SPDX-License-Identifier: Elastic-2.0

""" """
from setuptools import setup, find_packages

setup(
name="og_serving",
version="0.3.6",
description="Open source code interpreter agent",
author="imotai",
author_email="wangtaize@dbpunk.com",
url="https://github.com/dbpunk-labs/octogen",
long_description=open("README.md").read(),
long_description_content_type="text/markdown",
packages=[
"og_serving",
],
package_dir={
"og_serving": "src/og_serving",
},
install_requires=["fastapi", "pydantic_settings"],
package_data={},
entry_points={
"console_scripts": [
"og_serving_http_server = og_serving.http_serving:run_serving",
]
},
)
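Once the package is installed (for example, pip install ./serving), the console_scripts entry point above exposes an og_serving_http_server command that maps to run_serving in og_serving.http_serving, shown below.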
Empty file.
36 changes: 36 additions & 0 deletions serving/src/og_serving/http_serving.py
@@ -0,0 +1,36 @@
# vim:fenc=utf-8

# SPDX-FileCopyrightText: 2023 imotai <jackwang@octogen.dev>
# SPDX-FileContributor: imotai
#
# SPDX-License-Identifier: Elastic-2.0

""" """
import os
import sys
import uvicorn
import logging
from dotenv import dotenv_values
from .server_app import create_app, Settings

config = dotenv_values(".env")

settings = Settings(_env_file="model.env")
LOG_LEVEL = (
logging.DEBUG if config.get("log_level", "info") == "debug" else logging.INFO
)

logging.basicConfig(
level=LOG_LEVEL,
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
handlers=[logging.StreamHandler(sys.stdout)],
)

logger = logging.getLogger(__name__)

def run_serving():
app = create_app(settings)
host = config.get("host", "localhost")
port = int(config.get("port", "9517"))
logger.info(f"Starting serving at {host}:{port}")
uvicorn.run(app, host=host, port=port)
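A hedged configuration sketch: run_serving reads host, port, and log_level from a local .env via dotenv_values, while model settings come from model.env through the pydantic-settings Settings class defined in server_app.py (not included in this diff). The keys below are the ones the code above actually reads; the values are examples:

# .env (read by dotenv_values above)
#   host=127.0.0.1
#   port=9517
#   log_level=debug
#
# model.env is passed to Settings(_env_file="model.env"); its keys are
# defined by server_app.Settings, which this diff does not show.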
(remaining file diffs not loaded)
