Commit

tok-1
irexyc committed Dec 18, 2024
1 parent aa5573d commit 1c30724
Showing 2 changed files with 202 additions and 4 deletions.
77 changes: 73 additions & 4 deletions benchmark/profile_throughput.py
@@ -7,16 +7,24 @@
import os
import random
import time
from multiprocessing import Process
from queue import Queue
from typing import List, Tuple, Union
from typing import Dict, List, Tuple, Union

import numpy as np
import zmq.asyncio
from tqdm import tqdm

from lmdeploy.cli.utils import ArgumentHelper, DefaultsAndTypesHelpFormatter
from lmdeploy.messages import (GenerationConfig, PytorchEngineConfig,
TurbomindEngineConfig)
from lmdeploy.pytorch.engine import EngineInstance
from lmdeploy.serve.tokenization import (DeTokenizeInput, DeTokenizeOutput,
ProcessArgs, ProcessOutput,
TokenizeInput, TokenizeOutput,
get_zmq_socket,
run_detokenize_process,
run_tokenize_process)
from lmdeploy.tokenizer import DetokenizeState, Tokenizer
from lmdeploy.utils import get_logger

@@ -87,6 +95,33 @@ def __init__(self, model_path: str,
self.csv = csv
self.pbar = None

async def _tokenize_one_request(self, input: TokenizeInput):
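        # Register an event for this session, push the prompt to the tokenizer
        # process, then poll the reply socket until this session's ids arrive.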
self.tokenize_state[input.session_id] = ProcessOutput(
None, asyncio.Event())
await self.send_to_tokenizer.send_pyobj(input)
while not self.tokenize_state[input.session_id].event.is_set():
await self._tokenize_step()
return self.tokenize_state[input.session_id].result.input_ids

async def _tokenize_step(self):
result: TokenizeOutput = await self.recv_from_tokenizer.recv_pyobj()
self.tokenize_state[result.session_id].result = result
self.tokenize_state[result.session_id].event.set()

async def _detokenize_one_request(self, input: DeTokenizeInput):
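        # Counterpart of _tokenize_one_request for the detokenizer process. A
        # sequence_start chunk only primes the worker's state and gets no reply.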
self.detokenize_state[input.session_id] = ProcessOutput(
None, asyncio.Event())
await self.send_to_detokenizer.send_pyobj(input)
if not input.sequence_start:
while not self.detokenize_state[input.session_id].event.is_set():
await self._detokenize_step()
return self.detokenize_state[input.session_id].result.response

async def _detokenize_step(self):
result: DeTokenizeOutput = await self.recv_from_detokenizer.recv_pyobj()
self.detokenize_state[result.session_id].result = result
self.detokenize_state[result.session_id].event.set()

async def _inference(self, req_queue: Queue, session_id: int,
temperature: float, top_p: float, top_k: int,
stream_output: bool, skip_tokenize: bool,
@@ -102,13 +137,19 @@ async def _inference(req_queue: Queue, session_id: int,
if skip_tokenize:
input_ids = prompt
else:
input_ids = self.tokenizer(prompt).input_ids
input_ids = await self._tokenize_one_request(
TokenizeInput(session_id, prompt, True))
# input_ids = self.tokenizer(prompt).input_ids

state = DetokenizeState(len(input_ids))

prev_len = 0
token_ids = input_ids.copy()

if not skip_detokenize:
await self._detokenize_one_request(
DeTokenizeInput(session_id, True, input_ids))

async for outputs in model_inst.async_stream_infer(
session_id,
input_ids=input_ids,
@@ -124,8 +165,12 @@ async def _inference(req_queue: Queue, session_id: int,
if n_token > prev_len:
token_ids += outputs.token_ids[prev_len - n_token:]
if not skip_detokenize:
_, state = self.tokenizer.detokenize_incrementally(
token_ids, state)
# _, state = self.tokenizer.detokenize_incrementally(
# token_ids, state)
_ = await self._detokenize_one_request(
DeTokenizeInput(
session_id, False,
outputs.token_ids[prev_len - n_token:]))
ts.append(time.perf_counter())
ns.append(n_token)
prev_len = n_token
@@ -155,6 +200,30 @@ def process_request(self, requests, concurrency, temperature, top_p, top_k,
event_loop = asyncio.new_event_loop()
asyncio.set_event_loop(event_loop)

process_args = ProcessArgs.init_new()
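        # Spawn the tokenizer worker in its own process and talk to it through
        # the ZMQ ipc sockets created from the endpoint names above.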
context = zmq.asyncio.Context()
self.tokenize_state: Dict[int, ProcessOutput] = {}
self.send_to_tokenizer = get_zmq_socket(context, zmq.PUSH,
process_args.to_tokenize_name)
self.recv_from_tokenizer = get_zmq_socket(
context, zmq.PULL, process_args.from_tokenize_name)
tokenize_proc = Process(
target=run_tokenize_process,
args=(self.tokenizer, process_args),
)
tokenize_proc.start()

self.detokenize_state: Dict[int, ProcessOutput] = {}
self.send_to_detokenizer = get_zmq_socket(
context, zmq.PUSH, process_args.to_detokenize_name)
self.recv_from_detokenizer = get_zmq_socket(
context, zmq.PULL, process_args.from_detokenize_name)
detokenize_proc = Process(
target=run_detokenize_process,
args=(self.tokenizer, process_args),
)
detokenize_proc.start()

# start threads
tasks = []
for i in range(concurrency):
129 changes: 129 additions & 0 deletions lmdeploy/serve/tokenization.py
@@ -0,0 +1,129 @@
# Copyright (c) OpenMMLab. All rights reserved.
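"""Tokenize/detokenize worker processes that exchange requests and results
with the engine over ZMQ IPC sockets."""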
import asyncio
import tempfile
from dataclasses import dataclass
from typing import List, Optional, Union

import zmq

from lmdeploy.tokenizer import DetokenizeState, Tokenizer


@dataclass
class TokenizeInput:
session_id: int
prompt: str
add_bos: bool

@dataclass
class TokenizeOutput:
session_id: int
input_ids: List[int]

@dataclass
class DeTokenizeInput:
session_id: int
sequence_start: bool
input_ids: List[int]
skip_special_tokens: bool = True
spaces_between_special_tokens: bool = True
state: Optional[DetokenizeState] = None

@dataclass
class DeTokenizeOutput:
session_id: int
response: str

@dataclass
class ProcessOutput:
result: Union[TokenizeOutput, DeTokenizeOutput, None]
event: asyncio.Event


@dataclass
class ProcessArgs:
"""ipc args."""

to_tokenize_name: str
from_tokenize_name: str
to_detokenize_name: str
from_detokenize_name: str

@staticmethod
def init_new():
return ProcessArgs(
to_tokenize_name=tempfile.NamedTemporaryFile(delete=False).name,
from_tokenize_name=tempfile.NamedTemporaryFile(delete=False).name,
to_detokenize_name=tempfile.NamedTemporaryFile(delete=False).name,
from_detokenize_name=tempfile.NamedTemporaryFile(delete=False).name)


def get_zmq_socket(context: zmq.Context, socket_type: zmq.SocketType, endpoint: str):
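    # PUSH sockets connect and PULL sockets bind to the ipc endpoint, so either
    # end of a pipe can be created first; HWM=0 removes the queued-message limit.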
socket = context.socket(socket_type)
if socket_type == zmq.PUSH:
socket.setsockopt(zmq.SNDHWM, 0)
socket.setsockopt(zmq.SNDBUF, int(0.5 * 1024**3))
socket.connect(f"ipc://{endpoint}")
elif socket_type == zmq.PULL:
socket.setsockopt(zmq.RCVHWM, 0)
socket.setsockopt(zmq.RCVBUF, int(0.5 * 1024**3))
socket.bind(f"ipc://{endpoint}")
else:
raise ValueError(f"Unsupported socket type: {socket_type}")
return socket



class Tokenize:
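    """Tokenizer worker: pulls TokenizeInput messages from the engine and
    pushes back TokenizeOutput messages with the encoded token ids."""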

def __init__(self, tokenizer, process_args: ProcessArgs):
self.tokenizer = tokenizer
context = zmq.Context(2)
self.recv_from_engine = get_zmq_socket(context, zmq.PULL, process_args.to_tokenize_name)
self.send_to_engine = get_zmq_socket(context, zmq.PUSH, process_args.from_tokenize_name)

def event_loop(self):
while True:
recv_obj: TokenizeInput = self.recv_from_engine.recv_pyobj()
input_ids = self.tokenizer.encode(recv_obj.prompt, add_bos=recv_obj.add_bos)
self.send_to_engine.send_pyobj(TokenizeOutput(session_id=recv_obj.session_id, input_ids=input_ids))


class DeTokenize:
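    """Detokenizer worker: keeps a per-session DetokenizeState and replies to
    each chunk of token ids with the incrementally decoded text."""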

def __init__(self, tokenizer, process_args: ProcessArgs):
self.tokenizer = tokenizer
context = zmq.Context()
self.recv_from_engine = get_zmq_socket(context, zmq.PULL, process_args.to_detokenize_name)
self.send_to_engine = get_zmq_socket(context, zmq.PUSH, process_args.from_detokenize_name)
self.state = {}

def event_loop(self):
while True:
recv_obj: DeTokenizeInput = self.recv_from_engine.recv_pyobj()
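            # A sequence_start message only initializes this session's state;
            # no response is sent until follow-up chunks arrive.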
if recv_obj.sequence_start:
recv_obj.state = DetokenizeState(len(recv_obj.input_ids))
_, recv_obj.state = self.tokenizer.detokenize_incrementally(
recv_obj.input_ids, recv_obj.state)
self.state[recv_obj.session_id] = recv_obj
continue
obj: DeTokenizeInput = self.state.get(recv_obj.session_id)
obj.input_ids += recv_obj.input_ids
response, obj.state = self.tokenizer.detokenize_incrementally(
obj.input_ids, obj.state)
self.send_to_engine.send_pyobj(DeTokenizeOutput(session_id=obj.session_id, response=response))


def run_tokenize_process(
tokenizer: Tokenizer,
process_args: ProcessArgs,
):
manager = Tokenize(tokenizer, process_args)
manager.event_loop()


def run_detokenize_process(
tokenizer: Tokenizer,
process_args: ProcessArgs,
):
manager = DeTokenize(tokenizer, process_args)
manager.event_loop()
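
A minimal usage sketch (not part of the commit) showing how the new helpers fit
together; it mirrors the setup in profile_throughput.py but drives only the
tokenizer path with plain synchronous sockets. The model_path value is a
placeholder for any model the Tokenizer can load.

from multiprocessing import Process

import zmq

from lmdeploy.serve.tokenization import (ProcessArgs, TokenizeInput,
                                         TokenizeOutput, get_zmq_socket,
                                         run_tokenize_process)
from lmdeploy.tokenizer import Tokenizer

if __name__ == '__main__':
    model_path = 'path/to/model'  # placeholder, not taken from the commit
    tokenizer = Tokenizer(model_path)
    process_args = ProcessArgs.init_new()  # creates the four ipc endpoints
    Process(target=run_tokenize_process, args=(tokenizer, process_args),
            daemon=True).start()

    context = zmq.Context()
    send = get_zmq_socket(context, zmq.PUSH, process_args.to_tokenize_name)
    recv = get_zmq_socket(context, zmq.PULL, process_args.from_tokenize_name)

    send.send_pyobj(TokenizeInput(session_id=0, prompt='hello', add_bos=True))
    out: TokenizeOutput = recv.recv_pyobj()  # blocks until the worker replies
    print(out.input_ids)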
