diff --git a/agent/src/og_agent/agent_api_server.py b/agent/src/og_agent/agent_api_server.py index 64016e2..abd9a56 100644 --- a/agent/src/og_agent/agent_api_server.py +++ b/agent/src/og_agent/agent_api_server.py @@ -19,11 +19,19 @@ from fastapi.param_functions import Header, Annotated from dotenv import dotenv_values -logger = logging.getLogger(__name__) - -# the agent config +# the api server config config = dotenv_values(".env") +LOG_LEVEL = ( + logging.DEBUG if config.get("log_level", "info") == "debug" else logging.INFO +) +logging.basicConfig( + level=LOG_LEVEL, + format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", + handlers=[logging.StreamHandler(sys.stdout)], +) +logger = logging.getLogger(__name__) + app = FastAPI() # the agent endpoint listen_addr = "%s:%s" % ( @@ -32,7 +40,6 @@ ) if config.get("rpc_host", "") == "0.0.0.0": listen_addr = "127.0.0.1:%s" % config.get("rpc_port", "9528") -logger.info(f"connect the agent server at {listen_addr}") agent_sdk = AgentProxySDK(listen_addr) @@ -191,6 +198,7 @@ async def process_task( async def run_server(): + logger.info(f"connect the agent server at {listen_addr}") port = int(config.get("rpc_port", "9528")) + 1 server_config = uvicorn.Config( app, host=config.get("rpc_host", "127.0.0.1"), port=port diff --git a/agent/src/og_agent/base_agent.py b/agent/src/og_agent/base_agent.py index 1c7a1a0..b42131f 100644 --- a/agent/src/og_agent/base_agent.py +++ b/agent/src/og_agent/base_agent.py @@ -52,9 +52,179 @@ class TypingState: class BaseAgent: - def __init__(self, sdk): self.kernel_sdk = sdk + self.model_name = "" + + def _merge_delta_for_function_call(self, message, delta): + if len(message.keys()) == 0: + message.update(delta) + return + if "function_call" not in message: + message["function_call"] = delta["function_call"] + return + old_arguments = message["function_call"].get("arguments", "") + if delta["function_call"]["arguments"]: + message["function_call"]["arguments"] = ( + old_arguments + delta["function_call"]["arguments"] + ) + + def _merge_delta_for_content(self, message, delta): + if not delta: + return + content = message.get("content", "") + if delta.get("content"): + message["content"] = content + delta["content"] + + def _get_function_call_argument_new_typing(self, message): + if message["function_call"]["name"] == "python": + return TypingState.CODE, "", message["function_call"].get("arguments", "") + + arguments = message["function_call"].get("arguments", "") + state = TypingState.START + explanation_str = "" + code_str = "" + for token_state, token in tokenize(io.StringIO(arguments)): + if token_state == None: + if state == TypingState.EXPLANATION and token[0] == 1: + explanation_str = token[1] + state = TypingState.START + if state == TypingState.CODE and token[0] == 1: + code_str = token[1] + state = TypingState.START + if token[1] == "explanation": + state = TypingState.EXPLANATION + if token[1] == "code": + state = TypingState.CODE + else: + # String + if token_state == 9 and state == TypingState.EXPLANATION: + explanation_str = "".join(token) + elif token_state == 9 and state == TypingState.CODE: + code_str = "".join(token) + return (state, explanation_str, code_str) + + def _get_message_token_count(self, message): + response_token_count = 0 + if "function_call" in message and message["function_call"]: + arguments = message["function_call"].get("arguments", "") + response_token_count += len(encoding.encode(arguments)) + if "content" in message and message["content"]: + response_token_count += 
len(encoding.encode(message.get("content")))
+        return response_token_count
+
+    async def extract_message_for_json(self, response_generator,
+        messages, queue, rpc_context, task_context, task_opt):
+
+    async def extract_message(self,
+        response_generator,
+        messages,
+        queue,
+        rpc_context,
+        task_context, task_opt):
+        """
+        extract the messages from the response generator
+        """
+        input_token_count = 0
+        for message in messages:
+            if not message["content"]:
+                continue
+            input_token_count += len(encoding.encode(message["content"]))
+        task_context.input_token_count += input_token_count
+        message = {}
+        text_content = ""
+        code_content = ""
+        context_output_token_count = task_context.output_token_count
+        start_time = time.time()
+        async for chunk in response_generator:
+            if rpc_context.done():
+                logger.debug("the client has cancelled the request")
+                break
+            if not chunk["choices"]:
+                continue
+            task_context.llm_name = chunk.get("model", "")
+            self.model_name = chunk.get("model", "")
+            delta = chunk["choices"][0]["delta"]
+            if "function_call" in delta:
+                self._merge_delta_for_function_call(message, delta)
+                response_token_count = self._get_message_token_count(message)
+                task_context.output_token_count = (
+                    response_token_count + context_output_token_count
+                )
+                task_context.llm_response_duration += int(
+                    (time.time() - start_time) * 1000
+                )
+                start_time = time.time()
+                (
+                    state,
+                    explanation_str,
+                    code_str,
+                ) = self._get_function_call_argument_new_typing(message)
+                logger.debug(
+                    f"argument explanation:{explanation_str} code:{code_str} text_content:{text_content}"
+                )
+                if explanation_str and text_content != explanation_str:
+                    typed_chars = explanation_str[len(text_content) :]
+                    text_content = explanation_str
+                    if task_opt.streaming and len(typed_chars) > 0:
+                        task_response = TaskResponse(
+                            state=task_context.to_context_state_proto(),
+                            response_type=TaskResponse.OnModelTypeText,
+                            typing_content=TypingContent(
+                                content=typed_chars, language="text"
+                            ),
+                        )
+                        await queue.put(task_response)
+                if code_str and code_content != code_str:
+                    typed_chars = code_str[len(code_content) :]
+                    code_content = code_str
+                    if task_opt.streaming and len(typed_chars) > 0:
+                        typing_language = "text"
+                        if delta["function_call"].get("name", "") in [
+                            "execute_python_code",
+                            "python",
+                        ]:
+                            typing_language = "python"
+                        elif (
+                            delta["function_call"].get("name", "")
+                            == "execute_bash_code"
+                        ):
+                            typing_language = "bash"
+                        await queue.put(
+                            TaskResponse(
+                                state=task_context.to_context_state_proto(),
+                                response_type=TaskResponse.OnModelTypeCode,
+                                typing_content=TypingContent(
+                                    content=typed_chars, language=typing_language
+                                ),
+                            )
+                        )
+            else:
+                self._merge_delta_for_content(message, delta)
+                task_context.llm_response_duration += int(
+                    (time.time() - start_time) * 1000
+                )
+                start_time = time.time()
+                if message.get("content") != None:
+                    response_token_count = self._get_message_token_count(message)
+                    task_context.output_token_count = (
+                        response_token_count + context_output_token_count
+                    )
+                    if task_opt.streaming and delta.get("content"):
+                        await queue.put(
+                            TaskResponse(
+                                state=task_context.to_context_state_proto(),
+                                response_type=TaskResponse.OnModelTypeText,
+                                typing_content=TypingContent(
+                                    content=delta["content"], language="text"
+                                ),
+                            )
+                        )
+        logger.info(
+            f"call the {self.model_name} with input token count {task_context.input_token_count} and output token count {task_context.output_token_count}"
+        )
+        return message
+
+
     async def call_function(self, code, context, task_context=None):
         """
diff --git
a/agent/src/og_agent/openai_agent.py b/agent/src/og_agent/openai_agent.py index fc506ff..0d0b443 100644 --- a/agent/src/og_agent/openai_agent.py +++ b/agent/src/og_agent/openai_agent.py @@ -68,7 +68,6 @@ class OpenaiAgent(BaseAgent): - def __init__(self, model, system_prompt, sdk, is_azure=True): super().__init__(sdk) self.model = model diff --git a/format.sh b/format.sh index 08b5695..a09f041 100644 --- a/format.sh +++ b/format.sh @@ -1,4 +1,4 @@ #! /bin/sh # # format.sh -pyink agent kernel chat up sdk examples +pyink agent kernel chat up sdk examples serving diff --git a/memory/src/og_memory/memory.py b/memory/src/og_memory/memory.py index 1b2e6b6..ff4e16a 100644 --- a/memory/src/og_memory/memory.py +++ b/memory/src/og_memory/memory.py @@ -9,16 +9,16 @@ """ - # import the agent memory -from og_proto.memory_pb2 import AgentMemory +from og_proto.memory_pb2 import AgentMemory as AgentMemoryProto from jinja2 import Environment from jinja2.loaders import PackageLoader env = Environment(loader=PackageLoader("og_memory", "template")) + context_tpl = env.get_template("agent.jinja") -def agent_memory_to_context(memory: AgentMemory): +def agent_memory_to_context(memory: AgentMemoryProto): """ Convert the agent memory to context :param memory : AgentMemory @@ -26,5 +26,8 @@ def agent_memory_to_context(memory: AgentMemory): """ return context_tpl.render(prompt=memory.instruction, guides=memory.guide_memory) +class AgentMemory(): + def __init__(self, path): + self.path = path diff --git a/requirements.txt b/requirements.txt index 320fcb9..5196778 100644 --- a/requirements.txt +++ b/requirements.txt @@ -21,4 +21,8 @@ tiktoken fastapi uvicorn pytest-mock +pydantic-settings +sse-starlette +starlette-context +llama-cpp-python diff --git a/serving/README.md b/serving/README.md new file mode 100644 index 0000000..a26c71b --- /dev/null +++ b/serving/README.md @@ -0,0 +1 @@ +the serving module for octogen diff --git a/serving/setup.py b/serving/setup.py new file mode 100644 index 0000000..5fbded6 --- /dev/null +++ b/serving/setup.py @@ -0,0 +1,32 @@ +# Copyright (C) 2023 dbpunk.com Author imotai +# SPDX-FileCopyrightText: 2023 imotai +# SPDX-FileContributor: imotai +# +# SPDX-License-Identifier: Elastic-2.0 + +""" """ +from setuptools import setup, find_packages + +setup( + name="og_serving", + version="0.3.6", + description="Open source code interpreter agent", + author="imotai", + author_email="wangtaize@dbpunk.com", + url="https://github.com/dbpunk-labs/octogen", + long_description=open("README.md").read(), + long_description_content_type="text/markdown", + packages=[ + "og_serving", + ], + package_dir={ + "og_serving": "src/og_serving", + }, + install_requires=["fastapi", "pydantic_settings"], + package_data={}, + entry_points={ + "console_scripts": [ + "og_serving_http_server = og_serving.http_serving:run_serving", + ] + }, +) diff --git a/serving/src/og_serving/__init__.py b/serving/src/og_serving/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/serving/src/og_serving/http_serving.py b/serving/src/og_serving/http_serving.py new file mode 100644 index 0000000..af091fa --- /dev/null +++ b/serving/src/og_serving/http_serving.py @@ -0,0 +1,36 @@ +# vim:fenc=utf-8 + +# SPDX-FileCopyrightText: 2023 imotai +# SPDX-FileContributor: imotai +# +# SPDX-License-Identifier: Elastic-2.0 + +""" """ +import os +import sys +import uvicorn +import logging +from dotenv import dotenv_values +from .server_app import create_app, Settings + +config = dotenv_values(".env") + +settings = 
Settings(_env_file="model.env") +LOG_LEVEL = ( + logging.DEBUG if config.get("log_level", "info") == "debug" else logging.INFO +) + +logging.basicConfig( + level=LOG_LEVEL, + format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", + handlers=[logging.StreamHandler(sys.stdout)], +) + +logger = logging.getLogger(__name__) + +def run_serving(): + app = create_app(settings) + host = config.get("host", "localhost") + port = int(config.get("port", "9517")) + logger.info(f"Starting serving at {host}:{port}") + uvicorn.run(app, host=host, port=port) diff --git a/serving/src/og_serving/server_app.py b/serving/src/og_serving/server_app.py new file mode 100644 index 0000000..998f8b0 --- /dev/null +++ b/serving/src/og_serving/server_app.py @@ -0,0 +1,842 @@ +# vim:fenc=utf-8 + +# SPDX-FileCopyrightText: 2023 imotai +# SPDX-FileContributor: imotai +# +# SPDX-License-Identifier: Elastic-2.0 + +""" + +""" +import sys +import json +import traceback +import multiprocessing +import time +from re import compile, Match, Pattern +from threading import Lock +from functools import partial +from typing import Callable, Coroutine, Iterator, List, Optional, Tuple, Union, Dict +from typing_extensions import TypedDict, Literal + +import llama_cpp + +from llama_cpp.llama_grammar import LlamaGrammar +import anyio +from anyio.streams.memory import MemoryObjectSendStream +from starlette.concurrency import run_in_threadpool, iterate_in_threadpool +from fastapi import Depends, FastAPI, APIRouter, Request, Response +from fastapi.middleware import Middleware +from fastapi.middleware.cors import CORSMiddleware +from fastapi.responses import JSONResponse +from fastapi.routing import APIRoute +from pydantic import BaseModel, Field +from pydantic_settings import BaseSettings +from sse_starlette.sse import EventSourceResponse +from starlette_context import plugins +from starlette_context.middleware import RawContextMiddleware + +import numpy as np +import numpy.typing as npt + + +# Disable warning for model and model_alias settings +BaseSettings.model_config['protected_namespaces'] = () + + +class Settings(BaseSettings): + model: str = Field( + description="The path to the model to use for generating completions." + ) + model_alias: Optional[str] = Field( + default=None, + description="The alias of the model to use for generating completions.", + ) + seed: int = Field(default=llama_cpp.LLAMA_DEFAULT_SEED, description="Random seed. -1 for random.") + n_ctx: int = Field(default=2048, ge=1, description="The context size.") + n_batch: int = Field( + default=512, ge=1, description="The batch size to use per eval." + ) + n_gpu_layers: int = Field( + default=0, + ge=-1, + description="The number of layers to put on the GPU. The rest will be on the CPU. 
Set -1 to move all to GPU.", + ) + main_gpu: int = Field( + default=0, + ge=0, + description="Main GPU to use.", + ) + tensor_split: Optional[List[float]] = Field( + default=None, + description="Split layers across multiple GPUs in proportion.", + ) + rope_freq_base: float = Field( + default=0.0, description="RoPE base frequency" + ) + rope_freq_scale: float = Field( + default=0.0, description="RoPE frequency scaling factor" + ) + mul_mat_q: bool = Field( + default=True, description="if true, use experimental mul_mat_q kernels" + ) + f16_kv: bool = Field(default=True, description="Whether to use f16 key/value.") + logits_all: bool = Field(default=True, description="Whether to return logits.") + vocab_only: bool = Field( + default=False, description="Whether to only return the vocabulary." + ) + use_mmap: bool = Field( + default=llama_cpp.llama_mmap_supported(), + description="Use mmap.", + ) + use_mlock: bool = Field( + default=llama_cpp.llama_mlock_supported(), + description="Use mlock.", + ) + embedding: bool = Field(default=True, description="Whether to use embeddings.") + n_threads: int = Field( + default=max(multiprocessing.cpu_count() // 2, 1), + ge=1, + description="The number of threads to use.", + ) + last_n_tokens_size: int = Field( + default=64, + ge=0, + description="Last n tokens to keep for repeat penalty calculation.", + ) + lora_base: Optional[str] = Field( + default=None, + description="Optional path to base model, useful if using a quantized base model and you want to apply LoRA to an f16 model." + ) + lora_path: Optional[str] = Field( + default=None, + description="Path to a LoRA file to apply to the model.", + ) + numa: bool = Field( + default=False, + description="Enable NUMA support.", + ) + chat_format: str = Field( + default="llama-2", + description="Chat format to use.", + ) + cache: bool = Field( + default=False, + description="Use a cache to reduce processing times for evaluated prompts.", + ) + cache_type: Literal["ram", "disk"] = Field( + default="ram", + description="The type of cache to use. Only used if cache is True.", + ) + cache_size: int = Field( + default=2 << 30, + description="The size of the cache in bytes. Only used if cache is True.", + ) + verbose: bool = Field( + default=True, description="Whether to print debug information." + ) + host: str = Field(default="localhost", description="Listen address") + port: int = Field(default=8000, description="Listen port") + interrupt_requests: bool = Field( + default=True, + description="Whether to interrupt requests when a new request is received.", + ) + + +class ErrorResponse(TypedDict): + """OpenAI style error response""" + + message: str + type: str + param: Optional[str] + code: Optional[str] + + +class ErrorResponseFormatters: + """Collection of formatters for error responses. + + Args: + request (Union[CreateCompletionRequest, CreateChatCompletionRequest]): + Request body + match (Match[str]): Match object from regex pattern + + Returns: + Tuple[int, ErrorResponse]: Status code and error response + """ + + @staticmethod + def context_length_exceeded( + request: Union["CreateCompletionRequest", "CreateChatCompletionRequest"], + match, # type: Match[str] # type: ignore + ) -> Tuple[int, ErrorResponse]: + """Formatter for context length exceeded error""" + + context_window = int(match.group(2)) + prompt_tokens = int(match.group(1)) + completion_tokens = request.max_tokens + if hasattr(request, "messages"): + # Chat completion + message = ( + "This model's maximum context length is {} tokens. 
" + "However, you requested {} tokens " + "({} in the messages, {} in the completion). " + "Please reduce the length of the messages or completion." + ) + else: + # Text completion + message = ( + "This model's maximum context length is {} tokens, " + "however you requested {} tokens " + "({} in your prompt; {} for the completion). " + "Please reduce your prompt; or completion length." + ) + return 400, ErrorResponse( + message=message.format( + context_window, + completion_tokens + prompt_tokens, + prompt_tokens, + completion_tokens, + ), + type="invalid_request_error", + param="messages", + code="context_length_exceeded", + ) + + @staticmethod + def model_not_found( + request: Union["CreateCompletionRequest", "CreateChatCompletionRequest"], + match, # type: Match[str] # type: ignore + ) -> Tuple[int, ErrorResponse]: + """Formatter for model_not_found error""" + + model_path = str(match.group(1)) + message = f"The model `{model_path}` does not exist" + return 400, ErrorResponse( + message=message, + type="invalid_request_error", + param=None, + code="model_not_found", + ) + + +class RouteErrorHandler(APIRoute): + """Custom APIRoute that handles application errors and exceptions""" + + # key: regex pattern for original error message from llama_cpp + # value: formatter function + pattern_and_formatters: Dict[ + "Pattern", + Callable[ + [ + Union["CreateCompletionRequest", "CreateChatCompletionRequest"], + "Match[str]", + ], + Tuple[int, ErrorResponse], + ], + ] = { + compile( + r"Requested tokens \((\d+)\) exceed context window of (\d+)" + ): ErrorResponseFormatters.context_length_exceeded, + compile( + r"Model path does not exist: (.+)" + ): ErrorResponseFormatters.model_not_found, + } + + def error_message_wrapper( + self, + error: Exception, + body: Optional[ + Union[ + "CreateChatCompletionRequest", + "CreateCompletionRequest", + "CreateEmbeddingRequest", + ] + ] = None, + ) -> Tuple[int, ErrorResponse]: + """Wraps error message in OpenAI style error response""" + print(f"Exception: {str(error)}", file=sys.stderr) + traceback.print_exc(file=sys.stderr) + if body is not None and isinstance( + body, + ( + CreateCompletionRequest, + CreateChatCompletionRequest, + ), + ): + # When text completion or chat completion + for pattern, callback in self.pattern_and_formatters.items(): + match = pattern.search(str(error)) + if match is not None: + return callback(body, match) + + # Wrap other errors as internal server error + return 500, ErrorResponse( + message=str(error), + type="internal_server_error", + param=None, + code=None, + ) + + def get_route_handler( + self, + ) -> Callable[[Request], Coroutine[None, None, Response]]: + """Defines custom route handler that catches exceptions and formats + in OpenAI style error response""" + + original_route_handler = super().get_route_handler() + + async def custom_route_handler(request: Request) -> Response: + try: + start_sec = time.perf_counter() + response = await original_route_handler(request) + elapsed_time_ms = int((time.perf_counter() - start_sec) * 1000) + response.headers["openai-processing-ms"] = f"{elapsed_time_ms}" + return response + except Exception as exc: + json_body = await request.json() + try: + if "messages" in json_body: + # Chat completion + body: Optional[ + Union[ + CreateChatCompletionRequest, + CreateCompletionRequest, + CreateEmbeddingRequest, + ] + ] = CreateChatCompletionRequest(**json_body) + elif "prompt" in json_body: + # Text completion + body = CreateCompletionRequest(**json_body) + else: + # Embedding + body = 
CreateEmbeddingRequest(**json_body) + except Exception: + # Invalid request body + body = None + + # Get proper error message from the exception + ( + status_code, + error_message, + ) = self.error_message_wrapper(error=exc, body=body) + return JSONResponse( + {"error": error_message}, + status_code=status_code, + ) + + return custom_route_handler + + +router = APIRouter(route_class=RouteErrorHandler) + +settings: Optional[Settings] = None +llama: Optional[llama_cpp.Llama] = None + + +def create_app(settings: Optional[Settings] = None): + if settings is None: + settings = Settings() + + middleware = [ + Middleware(RawContextMiddleware, plugins=(plugins.RequestIdPlugin(),)) + ] + app = FastAPI( + middleware=middleware, + title="🦙 llama.cpp Python API", + version="0.0.1", + ) + app.add_middleware( + CORSMiddleware, + allow_origins=["*"], + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], + ) + app.include_router(router) + global llama + llama = llama_cpp.Llama( + model_path=settings.model, + seed=settings.seed, + n_ctx=settings.n_ctx, + n_batch=settings.n_batch, + n_gpu_layers=settings.n_gpu_layers, + main_gpu=settings.main_gpu, + tensor_split=settings.tensor_split, + rope_freq_base=settings.rope_freq_base, + rope_freq_scale=settings.rope_freq_scale, + mul_mat_q=settings.mul_mat_q, + f16_kv=settings.f16_kv, + logits_all=settings.logits_all, + vocab_only=settings.vocab_only, + use_mmap=settings.use_mmap, + use_mlock=settings.use_mlock, + embedding=settings.embedding, + n_threads=settings.n_threads, + last_n_tokens_size=settings.last_n_tokens_size, + lora_base=settings.lora_base, + lora_path=settings.lora_path, + numa=settings.numa, + chat_format=settings.chat_format, + verbose=settings.verbose, + ) + if settings.cache: + if settings.cache_type == "disk": + if settings.verbose: + print(f"Using disk cache with size {settings.cache_size}") + cache = llama_cpp.LlamaDiskCache(capacity_bytes=settings.cache_size) + else: + if settings.verbose: + print(f"Using ram cache with size {settings.cache_size}") + cache = llama_cpp.LlamaRAMCache(capacity_bytes=settings.cache_size) + + cache = llama_cpp.LlamaCache(capacity_bytes=settings.cache_size) + llama.set_cache(cache) + + def set_settings(_settings: Settings): + global settings + settings = _settings + + set_settings(settings) + return app + + +llama_outer_lock = Lock() +llama_inner_lock = Lock() + + +def get_llama(): + # NOTE: This double lock allows the currently streaming llama model to + # check if any other requests are pending in the same thread and cancel + # the stream if so. 
+ llama_outer_lock.acquire() + release_outer_lock = True + try: + llama_inner_lock.acquire() + try: + llama_outer_lock.release() + release_outer_lock = False + yield llama + finally: + llama_inner_lock.release() + finally: + if release_outer_lock: + llama_outer_lock.release() + + +def get_settings(): + yield settings + + +async def get_event_publisher( + request: Request, + inner_send_chan: MemoryObjectSendStream, + iterator: Iterator, +): + async with inner_send_chan: + try: + async for chunk in iterate_in_threadpool(iterator): + await inner_send_chan.send(dict(data=json.dumps(chunk))) + if await request.is_disconnected(): + raise anyio.get_cancelled_exc_class()() + if settings.interrupt_requests and llama_outer_lock.locked(): + await inner_send_chan.send(dict(data="[DONE]")) + raise anyio.get_cancelled_exc_class()() + await inner_send_chan.send(dict(data="[DONE]")) + except anyio.get_cancelled_exc_class() as e: + print("disconnected") + with anyio.move_on_after(1, shield=True): + print(f"Disconnected from client (via refresh/close) {request.client}") + raise e + + +model_field = Field( + description="The model to use for generating completions.", default=None +) + +max_tokens_field = Field( + default=16, ge=1, description="The maximum number of tokens to generate." +) + +temperature_field = Field( + default=0.8, + ge=0.0, + le=2.0, + description="Adjust the randomness of the generated text.\n\n" + + "Temperature is a hyperparameter that controls the randomness of the generated text. It affects the probability distribution of the model's output tokens. A higher temperature (e.g., 1.5) makes the output more random and creative, while a lower temperature (e.g., 0.5) makes the output more focused, deterministic, and conservative. The default value is 0.8, which provides a balance between randomness and determinism. At the extreme, a temperature of 0 will always pick the most likely next token, leading to identical outputs in each run.", +) + +top_p_field = Field( + default=0.95, + ge=0.0, + le=1.0, + description="Limit the next token selection to a subset of tokens with a cumulative probability above a threshold P.\n\n" + + "Top-p sampling, also known as nucleus sampling, is another text generation method that selects the next token from a subset of tokens that together have a cumulative probability of at least p. This method provides a balance between diversity and quality by considering both the probabilities of tokens and the number of tokens to sample from. A higher value for top_p (e.g., 0.95) will lead to more diverse text, while a lower value (e.g., 0.5) will generate more focused and conservative text.", +) + +stop_field = Field( + default=None, + description="A list of tokens at which to stop generation. If None, no stop tokens are used.", +) + +stream_field = Field( + default=False, + description="Whether to stream the results as they are generated. Useful for chatbots.", +) + +top_k_field = Field( + default=40, + ge=0, + description="Limit the next token selection to the K most probable tokens.\n\n" + + "Top-k sampling is a text generation method that selects the next token only from the top k most likely tokens predicted by the model. It helps reduce the risk of generating low-probability or nonsensical tokens, but it may also limit the diversity of the output. 
A higher value for top_k (e.g., 100) will consider more tokens and lead to more diverse text, while a lower value (e.g., 10) will focus on the most probable tokens and generate more conservative text.", +) + +repeat_penalty_field = Field( + default=1.1, + ge=0.0, + description="A penalty applied to each token that is already generated. This helps prevent the model from repeating itself.\n\n" + + "Repeat penalty is a hyperparameter used to penalize the repetition of token sequences during text generation. It helps prevent the model from generating repetitive or monotonous text. A higher value (e.g., 1.5) will penalize repetitions more strongly, while a lower value (e.g., 0.9) will be more lenient.", +) + +presence_penalty_field = Field( + default=0.0, + ge=-2.0, + le=2.0, + description="Positive values penalize new tokens based on whether they appear in the text so far, increasing the model's likelihood to talk about new topics.", +) + +frequency_penalty_field = Field( + default=0.0, + ge=-2.0, + le=2.0, + description="Positive values penalize new tokens based on their existing frequency in the text so far, decreasing the model's likelihood to repeat the same line verbatim.", +) + +mirostat_mode_field = Field( + default=0, + ge=0, + le=2, + description="Enable Mirostat constant-perplexity algorithm of the specified version (1 or 2; 0 = disabled)", +) + +mirostat_tau_field = Field( + default=5.0, + ge=0.0, + le=10.0, + description="Mirostat target entropy, i.e. the target perplexity - lower values produce focused and coherent text, larger values produce more diverse and less coherent text", +) + +mirostat_eta_field = Field( + default=0.1, ge=0.001, le=1.0, description="Mirostat learning rate" +) + + +class CreateCompletionRequest(BaseModel): + prompt: Union[str, List[str]] = Field( + default="", description="The prompt to generate completions for." + ) + suffix: Optional[str] = Field( + default=None, + description="A suffix to append to the generated text. If None, no suffix is appended. Useful for chatbots.", + ) + max_tokens: int = max_tokens_field + temperature: float = temperature_field + top_p: float = top_p_field + mirostat_mode: int = mirostat_mode_field + mirostat_tau: float = mirostat_tau_field + mirostat_eta: float = mirostat_eta_field + echo: bool = Field( + default=False, + description="Whether to echo the prompt in the generated text. Useful for chatbots.", + ) + stop: Optional[Union[str, List[str]]] = stop_field + stream: bool = stream_field + logprobs: Optional[int] = Field( + default=None, + ge=0, + description="The number of logprobs to generate. 
If None, no logprobs are generated.", + ) + presence_penalty: Optional[float] = presence_penalty_field + frequency_penalty: Optional[float] = frequency_penalty_field + logit_bias: Optional[Dict[str, float]] = Field(None) + logprobs: Optional[int] = Field(None) + + # ignored or currently unsupported + model: Optional[str] = model_field + n: Optional[int] = 1 + best_of: Optional[int] = 1 + user: Optional[str] = Field(default=None) + + # llama.cpp specific parameters + top_k: int = top_k_field + repeat_penalty: float = repeat_penalty_field + logit_bias_type: Optional[Literal["input_ids", "tokens"]] = Field(None) + grammar: str = Field(default=None) + model_config = { + "json_schema_extra": { + "examples": [ + { + "prompt": "\n\n### Instructions:\nWhat is the capital of France?\n\n### Response:\n", + "stop": ["\n", "###"], + } + ] + } + } + +def make_logit_bias_processor( + llama: llama_cpp.Llama, + logit_bias: Dict[str, float], + logit_bias_type: Optional[Literal["input_ids", "tokens"]], +): + if logit_bias_type is None: + logit_bias_type = "input_ids" + + to_bias: Dict[int, float] = {} + if logit_bias_type == "input_ids": + for input_id, score in logit_bias.items(): + input_id = int(input_id) + to_bias[input_id] = score + + elif logit_bias_type == "tokens": + for token, score in logit_bias.items(): + token = token.encode("utf-8") + for input_id in llama.tokenize(token, add_bos=False): + to_bias[input_id] = score + + def logit_bias_processor( + input_ids: npt.NDArray[np.intc], + scores: npt.NDArray[np.single], + ) -> npt.NDArray[np.single]: + new_scores = [None] * len(scores) + for input_id, score in enumerate(scores): + new_scores[input_id] = score + to_bias.get(input_id, 0.0) + + return new_scores + + return logit_bias_processor + + +@router.post( + "/v1/completions", +) +@router.post("/v1/engines/copilot-codex/completions") +async def create_completion( + request: Request, + body: CreateCompletionRequest, + llama: llama_cpp.Llama = Depends(get_llama), +) -> llama_cpp.Completion: + if isinstance(body.prompt, list): + assert len(body.prompt) <= 1 + body.prompt = body.prompt[0] if len(body.prompt) > 0 else "" + + exclude = { + "n", + "best_of", + "logit_bias", + "logit_bias_type", + "user", + } + kwargs = body.model_dump(exclude=exclude) + + if body.logit_bias is not None: + kwargs["logits_processor"] = llama_cpp.LogitsProcessorList( + [ + make_logit_bias_processor(llama, body.logit_bias, body.logit_bias_type), + ] + ) + + iterator_or_completion: Union[ + llama_cpp.Completion, Iterator[llama_cpp.CompletionChunk] + ] = await run_in_threadpool(llama, **kwargs) + + if isinstance(iterator_or_completion, Iterator): + # EAFP: It's easier to ask for forgiveness than permission + first_response = await run_in_threadpool(next, iterator_or_completion) + + # If no exception was raised from first_response, we can assume that + # the iterator is valid and we can use it to stream the response. 
+ def iterator() -> Iterator[llama_cpp.CompletionChunk]: + yield first_response + yield from iterator_or_completion + + send_chan, recv_chan = anyio.create_memory_object_stream(10) + return EventSourceResponse( + recv_chan, + data_sender_callable=partial( # type: ignore + get_event_publisher, + request=request, + inner_send_chan=send_chan, + iterator=iterator(), + ), + ) + else: + return iterator_or_completion + + +class CreateEmbeddingRequest(BaseModel): + model: Optional[str] = model_field + input: Union[str, List[str]] = Field(description="The input to embed.") + user: Optional[str] = Field(default=None) + + model_config = { + "json_schema_extra": { + "examples": [ + { + "input": "The food was delicious and the waiter...", + } + ] + } + } + + +@router.post( + "/v1/embeddings", +) +async def create_embedding( + request: CreateEmbeddingRequest, llama: llama_cpp.Llama = Depends(get_llama) +): + return await run_in_threadpool( + llama.create_embedding, **request.model_dump(exclude={"user"}) + ) + + +class ChatCompletionRequestMessage(BaseModel): + role: Literal["system", "user", "assistant"] = Field( + default="user", description="The role of the message." + ) + content: str = Field(default="", description="The content of the message.") + + +class CreateChatCompletionRequest(BaseModel): + messages: List[ChatCompletionRequestMessage] = Field( + default=[], description="A list of messages to generate completions for." + ) + functions: Optional[List[llama_cpp.ChatCompletionFunction]] = Field( + default=None, + description="A list of functions to apply to the generated completions.", + ) + function_call: Optional[Union[str, llama_cpp.ChatCompletionFunctionCall]] = Field( + default=None, + description="A function to apply to the generated completions.", + ) + max_tokens: int = max_tokens_field + temperature: float = temperature_field + top_p: float = top_p_field + mirostat_mode: int = mirostat_mode_field + mirostat_tau: float = mirostat_tau_field + mirostat_eta: float = mirostat_eta_field + stop: Optional[List[str]] = stop_field + stream: bool = stream_field + presence_penalty: Optional[float] = presence_penalty_field + frequency_penalty: Optional[float] = frequency_penalty_field + logit_bias: Optional[Dict[str, float]] = Field(None) + + # ignored or currently unsupported + model: Optional[str] = model_field + n: Optional[int] = 1 + user: Optional[str] = Field(None) + + # llama.cpp specific parameters + top_k: int = top_k_field + repeat_penalty: float = repeat_penalty_field + logit_bias_type: Optional[Literal["input_ids", "tokens"]] = Field(None) + + model_config = { + "json_schema_extra": { + "examples": [ + { + "messages": [ + ChatCompletionRequestMessage( + role="system", content="You are a helpful assistant." + ).model_dump(), + ChatCompletionRequestMessage( + role="user", content="What is the capital of France?" 
+                        ).model_dump(),
+                    ]
+                }
+            ]
+        }
+    }
+
+
+@router.post(
+    "/v1/chat/completions",
+)
+async def create_chat_completion(
+    request: Request,
+    body: CreateChatCompletionRequest,
+    llama: llama_cpp.Llama = Depends(get_llama),
+    settings: Settings = Depends(get_settings),
+) -> llama_cpp.ChatCompletion:
+    exclude = {
+        "n",
+        "logit_bias",
+        "logit_bias_type",
+        "user",
+    }
+    kwargs = body.model_dump(exclude=exclude)
+    if 'grammar' in kwargs and kwargs['grammar']:
+        kwargs['grammar'] = LlamaGrammar.from_string(kwargs['grammar'])
+    if body.logit_bias is not None:
+        kwargs["logits_processor"] = llama_cpp.LogitsProcessorList(
+            [
+                make_logit_bias_processor(llama, body.logit_bias, body.logit_bias_type),
+            ]
+        )
+    iterator_or_completion: Union[
+        llama_cpp.ChatCompletion, Iterator[llama_cpp.ChatCompletionChunk]
+    ] = await run_in_threadpool(llama.create_chat_completion, **kwargs)
+
+    if isinstance(iterator_or_completion, Iterator):
+        # EAFP: It's easier to ask for forgiveness than permission
+        first_response = await run_in_threadpool(next, iterator_or_completion)
+
+        # If no exception was raised from first_response, we can assume that
+        # the iterator is valid and we can use it to stream the response.
+        def iterator() -> Iterator[llama_cpp.ChatCompletionChunk]:
+            yield first_response
+            yield from iterator_or_completion
+
+        send_chan, recv_chan = anyio.create_memory_object_stream(10)
+        return EventSourceResponse(
+            recv_chan,
+            data_sender_callable=partial(  # type: ignore
+                get_event_publisher,
+                request=request,
+                inner_send_chan=send_chan,
+                iterator=iterator(),
+            ),
+        )
+    else:
+        return iterator_or_completion
+
+
+class ModelData(TypedDict):
+    id: str
+    object: Literal["model"]
+    owned_by: str
+    permissions: List[str]
+
+
+class ModelList(TypedDict):
+    object: Literal["list"]
+    data: List[ModelData]
+
+
+@router.get("/v1/models")
+async def get_models(
+    settings: Settings = Depends(get_settings),
+) -> ModelList:
+    assert llama is not None
+    return {
+        "object": "list",
+        "data": [
+            {
+                "id": settings.model_alias
+                if settings.model_alias is not None
+                else llama.model_path,
+                "object": "model",
+                "owned_by": "me",
+                "permissions": [],
+            }
+        ],
+    }
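For reviewers who want to exercise the new og_serving package without installing the og_serving_http_server console script, here is a minimal sketch that builds the app directly from server_app.create_app and Settings, mirroring what run_serving() in http_serving.py does (run_serving reads host/port/log_level from .env and the model settings from model.env). The GGUF model path is a hypothetical placeholder, not a file shipped with this change.

# Minimal sketch: run the serving app in-process, assuming og_serving is installed
# and a local GGUF model exists at the (hypothetical) path below.
import uvicorn

from og_serving.server_app import Settings, create_app

settings = Settings(
    model="/models/codellama-7b-instruct.Q4_K_M.gguf",  # hypothetical model path
    n_ctx=2048,  # context size, same default as Settings.n_ctx
    n_gpu_layers=0,  # keep all layers on the CPU; -1 offloads everything to the GPU
    chat_format="llama-2",
)

app = create_app(settings)
# host/port below match the .env fallbacks used by run_serving()
uvicorn.run(app, host="localhost", port=9517)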
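Once the server is up it exposes the OpenAI-style routes defined above (/v1/models, /v1/completions, /v1/embeddings, /v1/chat/completions). A standard-library client sketch, assuming the default localhost:9517 address from http_serving.py:

import json
from urllib.request import Request, urlopen

BASE = "http://localhost:9517"  # assumed: the .env fallbacks in http_serving.py

# List the loaded model (model_alias if set, otherwise the model path).
with urlopen(f"{BASE}/v1/models") as resp:
    print(json.load(resp))

# Non-streaming chat completion against the llama.cpp-backed endpoint.
payload = {
    "messages": [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "What is the capital of France?"},
    ],
    "max_tokens": 64,
    "stream": False,
}
req = Request(
    f"{BASE}/v1/chat/completions",
    data=json.dumps(payload).encode("utf-8"),
    headers={"Content-Type": "application/json"},
)
with urlopen(req) as resp:
    print(json.load(resp))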