diff --git a/benchmark/profile_throughput.py b/benchmark/profile_throughput.py
index 34f31b4137..ffe0440b03 100644
--- a/benchmark/profile_throughput.py
+++ b/benchmark/profile_throughput.py
@@ -7,16 +7,24 @@
 import os
 import random
 import time
+from multiprocessing import Process
 from queue import Queue
-from typing import List, Tuple, Union
+from typing import Dict, List, Tuple, Union
 
 import numpy as np
+import zmq.asyncio
 from tqdm import tqdm
 
 from lmdeploy.cli.utils import ArgumentHelper, DefaultsAndTypesHelpFormatter
 from lmdeploy.messages import (GenerationConfig, PytorchEngineConfig,
                                TurbomindEngineConfig)
 from lmdeploy.pytorch.engine import EngineInstance
+from lmdeploy.serve.tokenization import (DeTokenizeInput, DeTokenizeOutput,
+                                         ProcessArgs, ProcessOutput,
+                                         TokenizeInput, TokenizeOutput,
+                                         get_zmq_socket,
+                                         run_detokenize_process,
+                                         run_tokenize_process)
 from lmdeploy.tokenizer import DetokenizeState, Tokenizer
 from lmdeploy.utils import get_logger
 
@@ -87,6 +95,33 @@ def __init__(self, model_path: str,
         self.csv = csv
         self.pbar = None
 
+    async def _tokenize_one_request(self, input: TokenizeInput):
+        self.tokenize_state[input.session_id] = ProcessOutput(
+            None, asyncio.Event())
+        await self.send_to_tokenizer.send_pyobj(input)
+        while not self.tokenize_state[input.session_id].event.is_set():
+            await self._tokenize_step()
+        return self.tokenize_state[input.session_id].result.input_ids
+
+    async def _tokenize_step(self):
+        result: TokenizeOutput = await self.recv_from_tokenizer.recv_pyobj()
+        self.tokenize_state[result.session_id].result = result
+        self.tokenize_state[result.session_id].event.set()
+
+    async def _detokenize_one_request(self, input: DeTokenizeInput):
+        self.detokenize_state[input.session_id] = ProcessOutput(
+            None, asyncio.Event())
+        await self.send_to_detokenizer.send_pyobj(input)
+        if not input.sequence_start:
+            while not self.detokenize_state[input.session_id].event.is_set():
+                await self._detokenize_step()
+            return self.detokenize_state[input.session_id].result.response
+
+    async def _detokenize_step(self):
+        result: DeTokenizeOutput = await self.recv_from_detokenizer.recv_pyobj()
+        self.detokenize_state[result.session_id].result = result
+        self.detokenize_state[result.session_id].event.set()
+
     async def _inference(self, req_queue: Queue, session_id: int,
                          temperature: float, top_p: float, top_k: int,
                          stream_output: bool, skip_tokenize: bool,
@@ -102,13 +137,19 @@ async def _inference(self, req_queue: Queue, session_id: int,
             if skip_tokenize:
                 input_ids = prompt
             else:
-                input_ids = self.tokenizer(prompt).input_ids
+                input_ids = await self._tokenize_one_request(
+                    TokenizeInput(session_id, prompt, True))
+                # input_ids = self.tokenizer(prompt).input_ids
 
             state = DetokenizeState(len(input_ids))
 
             prev_len = 0
             token_ids = input_ids.copy()
 
+            if not skip_detokenize:
+                await self._detokenize_one_request(
+                    DeTokenizeInput(session_id, True, input_ids))
+
             async for outputs in model_inst.async_stream_infer(
                     session_id,
                     input_ids=input_ids,
@@ -124,8 +165,12 @@ async def _inference(self, req_queue: Queue, session_id: int,
                 if n_token > prev_len:
                     token_ids += outputs.token_ids[prev_len - n_token:]
                     if not skip_detokenize:
-                        _, state = self.tokenizer.detokenize_incrementally(
-                            token_ids, state)
+                        # _, state = self.tokenizer.detokenize_incrementally(
+                        #     token_ids, state)
+                        _ = await self._detokenize_one_request(
+                            DeTokenizeInput(
+                                session_id, False,
+                                outputs.token_ids[prev_len - n_token:]))
                     ts.append(time.perf_counter())
                     ns.append(n_token)
                     prev_len = n_token
@@ -155,6 +200,30 @@ def process_request(self, requests, concurrency, temperature, top_p, top_k,
         event_loop = asyncio.new_event_loop()
         asyncio.set_event_loop(event_loop)
 
+        process_args = ProcessArgs.init_new()
+        context = zmq.asyncio.Context()
+        self.tokenize_state: Dict[int, ProcessOutput] = {}
+        self.send_to_tokenizer = get_zmq_socket(context, zmq.PUSH,
+                                                process_args.to_tokenize_name)
+        self.recv_from_tokenizer = get_zmq_socket(
+            context, zmq.PULL, process_args.from_tokenize_name)
+        tokenize_proc = Process(
+            target=run_tokenize_process,
+            args=(self.tokenizer, process_args),
+        )
+        tokenize_proc.start()
+
+        self.detokenize_state: Dict[int, ProcessOutput] = {}
+        self.send_to_detokenizer = get_zmq_socket(
+            context, zmq.PUSH, process_args.to_detokenize_name)
+        self.recv_from_detokenizer = get_zmq_socket(
+            context, zmq.PULL, process_args.from_detokenize_name)
+        detokenize_proc = Process(
+            target=run_detokenize_process,
+            args=(self.tokenizer, process_args),
+        )
+        detokenize_proc.start()
+
         # start threads
         tasks = []
         for i in range(concurrency):
diff --git a/lmdeploy/serve/tokenization.py b/lmdeploy/serve/tokenization.py
new file mode 100644
index 0000000000..ec459331ba
--- /dev/null
+++ b/lmdeploy/serve/tokenization.py
@@ -0,0 +1,129 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import asyncio
+import tempfile
+from dataclasses import dataclass
+from typing import List, Union
+
+import zmq
+
+from lmdeploy.tokenizer import DetokenizeState, Tokenizer
+
+
+@dataclass
+class TokenizeInput:
+    session_id: int
+    prompt: str
+    add_bos: bool
+
+
+@dataclass
+class TokenizeOutput:
+    session_id: int
+    input_ids: List[int]
+
+
+@dataclass
+class DeTokenizeInput:
+    session_id: int
+    sequence_start: bool
+    input_ids: List[int]
+    skip_special_tokens: bool = True
+    spaces_between_special_tokens: bool = True
+    state: DetokenizeState = None
+
+
+@dataclass
+class DeTokenizeOutput:
+    session_id: int
+    response: str
+
+
+@dataclass
+class ProcessOutput:
+    result: Union[TokenizeOutput, DeTokenizeOutput, None]
+    event: asyncio.Event
+
+
+@dataclass
+class ProcessArgs:
+    """ipc args."""
+
+    to_tokenize_name: str
+    from_tokenize_name: str
+    to_detokenize_name: str
+    from_detokenize_name: str
+
+    @staticmethod
+    def init_new():
+        return ProcessArgs(
+            to_tokenize_name=tempfile.NamedTemporaryFile(delete=False).name,
+            from_tokenize_name=tempfile.NamedTemporaryFile(delete=False).name,
+            to_detokenize_name=tempfile.NamedTemporaryFile(delete=False).name,
+            from_detokenize_name=tempfile.NamedTemporaryFile(
+                delete=False).name)
+
+
+def get_zmq_socket(context: zmq.Context, socket_type: zmq.SocketType,
+                   endpoint: str):
+    socket = context.socket(socket_type)
+    if socket_type == zmq.PUSH:
+        socket.setsockopt(zmq.SNDHWM, 0)
+        socket.setsockopt(zmq.SNDBUF, int(0.5 * 1024**3))
+        socket.connect(f"ipc://{endpoint}")
+    elif socket_type == zmq.PULL:
+        socket.setsockopt(zmq.RCVHWM, 0)
+        socket.setsockopt(zmq.RCVBUF, int(0.5 * 1024**3))
+        socket.bind(f"ipc://{endpoint}")
+    else:
+        raise ValueError(f"Unsupported socket type: {socket_type}")
+    return socket
+
+
+class Tokenize:
+
+    def __init__(self, tokenizer, process_args: ProcessArgs):
+        self.tokenizer = tokenizer
+        context = zmq.Context(2)
+        self.recv_from_engine = get_zmq_socket(context, zmq.PULL,
+                                               process_args.to_tokenize_name)
+        self.send_to_engine = get_zmq_socket(context, zmq.PUSH,
+                                             process_args.from_tokenize_name)
+
+    def event_loop(self):
+        while True:
+            recv_obj: TokenizeInput = self.recv_from_engine.recv_pyobj()
+            input_ids = self.tokenizer.encode(recv_obj.prompt,
+                                              add_bos=recv_obj.add_bos)
+            self.send_to_engine.send_pyobj(
+                TokenizeOutput(session_id=recv_obj.session_id,
+                               input_ids=input_ids))
+
+
+class DeTokenize:
+
+    def __init__(self, tokenizer, process_args: ProcessArgs):
+        self.tokenizer = tokenizer
+        context = zmq.Context()
+        self.recv_from_engine = get_zmq_socket(
+            context, zmq.PULL, process_args.to_detokenize_name)
+        self.send_to_engine = get_zmq_socket(
+            context, zmq.PUSH, process_args.from_detokenize_name)
+        self.state = {}
+
+    def event_loop(self):
+        while True:
+            recv_obj: DeTokenizeInput = self.recv_from_engine.recv_pyobj()
+            if recv_obj.sequence_start:
+                recv_obj.state = DetokenizeState(len(recv_obj.input_ids))
+                _, recv_obj.state = self.tokenizer.detokenize_incrementally(
+                    recv_obj.input_ids, recv_obj.state)
+                self.state[recv_obj.session_id] = recv_obj
+                continue
+            obj: DeTokenizeInput = self.state.get(recv_obj.session_id)
+            obj.input_ids += recv_obj.input_ids
+            response, obj.state = self.tokenizer.detokenize_incrementally(
+                obj.input_ids, obj.state)
+            self.send_to_engine.send_pyobj(
+                DeTokenizeOutput(session_id=obj.session_id, response=response))
+
+
+def run_tokenize_process(
+    tokenizer: Tokenizer,
+    process_args: ProcessArgs,
+):
+    manager = Tokenize(tokenizer, process_args)
+    manager.event_loop()
+
+
+def run_detokenize_process(
+    tokenizer: Tokenizer,
+    process_args: ProcessArgs,
+):
+    manager = DeTokenize(tokenizer, process_args)
+    manager.event_loop()
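
Usage sketch (not part of the patch): a minimal, single-request example of how the new IPC plumbing is wired together on the engine side. ProcessArgs.init_new() allocates the ipc endpoint names, get_zmq_socket opens one PUSH/PULL pair per direction, and run_tokenize_process runs Tokenize.event_loop() in a worker process. The model path, the __main__ guard, and the use of blocking (non-asyncio) sockets below are illustrative assumptions; the benchmark itself drives the same sockets through zmq.asyncio.

    from multiprocessing import Process

    import zmq

    from lmdeploy.serve.tokenization import (ProcessArgs, TokenizeInput,
                                             TokenizeOutput, get_zmq_socket,
                                             run_tokenize_process)
    from lmdeploy.tokenizer import Tokenizer

    if __name__ == '__main__':
        tokenizer = Tokenizer('internlm/internlm2-chat-7b')  # placeholder path
        process_args = ProcessArgs.init_new()

        # Engine side: PUSH requests to the worker, PULL results back.
        context = zmq.Context()
        send_to_tokenizer = get_zmq_socket(context, zmq.PUSH,
                                           process_args.to_tokenize_name)
        recv_from_tokenizer = get_zmq_socket(context, zmq.PULL,
                                             process_args.from_tokenize_name)

        # Worker side: Tokenize.event_loop() blocks on recv_pyobj() in its own
        # process, so encoding does not contend with the engine's event loop.
        proc = Process(target=run_tokenize_process,
                       args=(tokenizer, process_args),
                       daemon=True)
        proc.start()

        send_to_tokenizer.send_pyobj(
            TokenizeInput(session_id=0, prompt='hello world', add_bos=True))
        out: TokenizeOutput = recv_from_tokenizer.recv_pyobj()
        print(out.session_id, out.input_ids)

        proc.terminate()

The DeTokenize worker follows the same request/response pattern, except that it caches a per-session DetokenizeState, so each subsequent DeTokenizeInput only needs to carry the newly generated token ids.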