
Commit

fix: fix llama agent bug
imotai committed Oct 27, 2023
1 parent 212950c commit 56a9114
Showing 5 changed files with 38 additions and 17 deletions.
16 changes: 12 additions & 4 deletions agent/src/og_agent/base_agent.py
@@ -91,6 +91,7 @@ def _parse_arguments(self, arguments, is_code=False):
        state = TypingState.START
        explanation_str = ""
        code_str = ""
+       logger.debug(f"the arguments {arguments}")
        for token_state, token in tokenize(io.StringIO(arguments)):
            if token_state == None:
                if state == TypingState.EXPLANATION and token[0] == 1:
@@ -149,7 +150,7 @@ async def _read_function_call_message(
    async def _read_json_message(
        self, message, queue, old_text_content, old_code_content, task_context, task_opt
    ):
-        arguments = messages.get("content", "")
+        arguments = message.get("content", "")
        typing_language = "text"
        if arguments.find("execute_python_code") >= 0:
            typing_language = "python"
@@ -158,10 +159,10 @@

        return await self._send_typing_message(
            arguments,
-            queue,
            old_text_content,
            old_code_content,
            typing_language,
+            queue,
            task_context,
            task_opt,
        )
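
The rename above is the heart of the fix: `messages` is not a local name inside `_read_json_message`, so the old line would raise a `NameError` (or pick up the wrong object) as soon as JSON-format streaming was used, and `queue` has been moved so that the positional arguments presumably line up with `_send_typing_message`'s signature. The language-detection heuristic visible in this hunk can be exercised on its own; a minimal, self-contained sketch (the function name is illustrative, not part of the codebase):

# Minimal sketch of the heuristic above: if the streamed function-call
# arguments mention execute_python_code, render them as Python, else as text.
def detect_typing_language(arguments: str) -> str:
    if arguments.find("execute_python_code") >= 0:
        return "python"
    return "text"

print(detect_typing_language('{"function_call": "execute_python_code", "arguments": "print(1)"}'))  # python
print(detect_typing_language('{"message": "hello"}'))  # text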
@@ -222,7 +223,7 @@ async def extract_message(
        is_json_format=False,
    ):
        """
-        extract the messages from the response generator
+        extract the chunk from the response generator
        """
        message = {}
        text_content = ""
@@ -235,6 +236,7 @@
                break
            if not chunk["choices"]:
                continue
+            logger.debug(f"the chunk {chunk}")
            task_context.llm_name = chunk.get("model", "")
            self.model_name = chunk.get("model", "")
            delta = chunk["choices"][0]["delta"]
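
The surrounding context shows the agent consuming OpenAI-style streaming chunks: each chunk carries the model name plus a `choices[0].delta` fragment, and the new debug line dumps the whole chunk. A minimal example of that chunk shape (values are illustrative):

# Illustrative streaming chunk in the OpenAI-compatible shape read above.
chunk = {
    "model": "codellama",
    "choices": [
        {"index": 0, "delta": {"content": "print("}, "finish_reason": None},
    ],
}
delta = chunk["choices"][0]["delta"]
print(chunk.get("model", ""), repr(delta.get("content")))  # codellama 'print('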
@@ -273,14 +275,20 @@
                response_token_count + context_output_token_count
            )
            if is_json_format:
-                await self._read_json_message(
+                (
+                    new_text_content,
+                    new_code_content,
+                ) = await self._read_json_message(
                    message,
                    queue,
                    text_content,
                    code_content,
                    task_context,
                    task_opt,
                )
+                text_content = new_text_content
+                code_content = new_code_content
+
            elif task_opt.streaming and delta.get("content"):
                await queue.put(
                    TaskResponse(
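The reworked call above implies that `_read_json_message` returns the text and code it has rendered so far, and that `extract_message` must thread those values back in on the next chunk; otherwise every chunk would be compared against empty strings and the same content re-sent. A self-contained sketch of that accumulate-and-diff pattern (all names here are illustrative, not the project's API):

# Emit only the new suffix of the accumulated content for each streamed
# chunk, then remember what has already been sent.
def emit_delta(accumulated: str, already_sent: str) -> tuple[str, str]:
    delta = accumulated[len(already_sent):]
    return delta, accumulated

already_sent = ""
for accumulated in ["Hel", "Hello", "Hello, wor", "Hello, world"]:
    delta, already_sent = emit_delta(accumulated, already_sent)
    print(repr(delta))  # 'Hel', 'lo', ', wor', 'ld'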
4 changes: 2 additions & 2 deletions agent/src/og_agent/llama_agent.py
@@ -174,15 +174,15 @@ async def call_llama(self, messages, queue, context, task_context, task_opt):
            input_token_count += len(encoding.encode(message["content"]))
        task_context.input_token_count += input_token_count
        start_time = time.time()
-        response = self.client.chat(messages, "codellama")
+        response = self.client.chat(messages, "codellama", max_tokens=2048)
        message = await self.extract_message(
            response,
            queue,
            context,
            task_context,
            task_opt,
            start_time,
-            is_json_format=False,
+            is_json_format=True,
        )
        return message

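Two behavioural changes for the llama path: the request now asks for up to 2048 tokens, and the streamed chunks are routed through the JSON reader (`is_json_format=True`) instead of being forwarded as plain text. That matches a backend that is constrained, via the grammar passed to `LlamaClient`, to answer in JSON. A sketch of the kind of payload this implies; the exact field names come from the grammar file, which is not part of this diff, so they are assumptions:

import json

# Hypothetical grammar-constrained reply from the codellama backend.
raw = '{"explanation": "add two numbers", "code": "print(1 + 1)"}'
reply = json.loads(raw)
print(reply["code"])  # print(1 + 1)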
7 changes: 3 additions & 4 deletions agent/src/og_agent/llama_client.py
@@ -20,9 +20,7 @@ def __init__(self, endpoint, key, grammar):
        super().__init__(endpoint + "/v1/chat/completions", key)
        self.grammar = grammar

-    async def chat(
-        self, messages, model, temperature=0, max_tokens=1024, stop=["</s>", "\n"]
-    ):
+    async def chat(self, messages, model, temperature=0, max_tokens=1024, stop=[]):
        data = {
            "messages": messages,
            "temperature": temperature,
@@ -31,8 +29,9 @@ async def chat(self, messages, model, temperature=0, max_tokens=1024, stop=[]):
            "model": model,
            "max_tokens": max_tokens,
            "top_p": 0.9,
-            "stop": stop,
        }
+        if stop:
+            data["stop"] = stop
        async for line in self.arun(data):
            if len(line) < 6:
                continue
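The previous default of `stop=["</s>", "\n"]` told the server to stop at the first newline, which truncates any multi-line JSON or code reply; the fix defaults to no stop sequences and only sends the field when the caller supplies one. A self-contained sketch of the resulting request-body logic (field names mirror the diff; the wrapper function is illustrative, and `stop=None` is used here instead of the mutable default `stop=[]`, which behaves the same since the list is never mutated):

def build_chat_payload(messages, model, temperature=0, max_tokens=1024, stop=None):
    # Mirror of the request body assembled above; "stop" is attached only
    # when the caller provides stop sequences.
    data = {
        "messages": messages,
        "temperature": temperature,
        "model": model,
        "max_tokens": max_tokens,
        "top_p": 0.9,
    }
    if stop:
        data["stop"] = stop
    return data

print("stop" in build_chat_payload([], "codellama"))               # False
print(build_chat_payload([], "codellama", stop=["</s>"])["stop"])  # ['</s>']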
8 changes: 1 addition & 7 deletions agent/src/og_agent/openai_agent.py
@@ -166,13 +166,7 @@ async def call_openai(self, messages, queue, context, task_context, task_opt):
            stream=True,
        )
        message = await self.extract_message(
-            response,
-            queue,
-            context,
-            task_context,
-            task_opt,
-            start_time,
-            is_json_format=False,
+            response, queue, context, task_context, task_opt, start_time
        )
        return message

20 changes: 20 additions & 0 deletions serving/src/og_serving/http_serving.py
@@ -12,6 +12,9 @@
import logging
from dotenv import dotenv_values
from .server_app import create_app, Settings
+from llama_cpp.llama_chat_format import register_chat_format, ChatFormatterResponse, _map_roles, _format_add_colon_single
+from llama_cpp import llama_types
+from typing import Any, List

config = dotenv_values(".env")

@@ -29,6 +32,23 @@
logger = logging.getLogger(__name__)


+@register_chat_format("phind")
+def format_phind(
+    messages: List[llama_types.ChatCompletionRequestMessage],
+    **kwargs: Any,
+) -> ChatFormatterResponse:
+    _roles = dict(user="### User Message", assistant="### Assistant")
+    _sep = "\n\n"
+    _system_message = "### System Prompt\nYou are an intelligent programming assistant."
+    for message in messages:
+        if message["role"] == "system" and message["content"]:
+            _system_message = f"""### System Prompt\n{message['content']}"""
+    _messages = _map_roles(messages, _roles)
+    _messages.append((_roles["assistant"], None))
+    _prompt = _format_add_colon_single(_system_message, _messages, _sep)
+    return ChatFormatterResponse(prompt=_prompt)
+
+
def run_serving():
    app = create_app(settings)
    host = config.get("host", "localhost")
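The new `format_phind` handler registers a prompt template under the name "phind" in llama-cpp-python's chat-format registry, so the serving layer can render OpenAI-style chat messages for Phind-CodeLlama-style models. Roughly, the rendered prompt looks like the output of the following self-contained re-implementation (illustrative only; the exact whitespace is decided by `_format_add_colon_single` inside llama-cpp-python):

# Sketch of the prompt shape the handler above aims for; this does not use
# llama-cpp-python internals and exists only to show the layout.
def render_phind_prompt(messages):
    roles = {"user": "### User Message", "assistant": "### Assistant"}
    sep = "\n\n"
    system = "### System Prompt\nYou are an intelligent programming assistant."
    for m in messages:
        if m["role"] == "system" and m["content"]:
            system = f"### System Prompt\n{m['content']}"
    prompt = system + sep
    for m in messages:
        if m["role"] in roles:
            prompt += f"{roles[m['role']]}: {m['content']}{sep}"
    prompt += "### Assistant:"
    return prompt

print(render_phind_prompt([
    {"role": "system", "content": "Answer concisely."},
    {"role": "user", "content": "Write hello world in Python."},
]))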
