From 9f0c63cba95d07d80b016b7c3a1891e5d75432d5 Mon Sep 17 00:00:00 2001 From: irexyc Date: Wed, 29 Nov 2023 11:09:46 +0800 Subject: [PATCH] support triton client --- lmdeploy/serve/im_client.py | 71 +++++ lmdeploy/serve/turbomind/chatbot.py | 198 +++++++------ lmdeploy/serve/turbomind/im_chatbot.py | 259 ++++++++++++++++++ .../triton_models/interactive/config.pbtxt | 12 + .../triton_models/xpreprocessing/1/model.py | 179 ++++++++++++ .../triton_models/xpreprocessing/config.pbtxt | 55 ++++ lmdeploy/serve/turbomind/utils.py | 55 ++++ .../turbomind/deploy/target_model/base.py | 1 + lmdeploy/turbomind/turbomind.py | 5 +- lmdeploy/xtokenizer.py | 121 ++++++++ 10 files changed, 872 insertions(+), 84 deletions(-) create mode 100644 lmdeploy/serve/im_client.py create mode 100644 lmdeploy/serve/turbomind/im_chatbot.py create mode 100644 lmdeploy/serve/turbomind/triton_models/xpreprocessing/1/model.py create mode 100644 lmdeploy/serve/turbomind/triton_models/xpreprocessing/config.pbtxt create mode 100644 lmdeploy/xtokenizer.py diff --git a/lmdeploy/serve/im_client.py b/lmdeploy/serve/im_client.py new file mode 100644 index 0000000000..697d8d4773 --- /dev/null +++ b/lmdeploy/serve/im_client.py @@ -0,0 +1,71 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import os + +from lmdeploy.serve.turbomind.im_chatbot import ImChatbot + + +def input_prompt(model_name): + """Input a prompt in the consolo interface.""" + if model_name == 'codellama': + print('\nenter !! to end the input >>>\n', end='') + sentinel = '!!' + else: + print('\ndouble enter to end input >>> ', end='') + sentinel = '' # ends when this string is seen + return '\n'.join(iter(input, sentinel)) + + +def main(tritonserver_addr: str, + session_id: int = 1, + cap: str = 'chat', + stream_output: bool = True, + **kwargs): + """An example to communicate with inference server through the command line + interface. + + Args: + tritonserver_addr (str): the address in format "ip:port" of + triton inference server + session_id (int): the identical id of a session + cap (str): the capability of a model. 
For example, codellama has + the ability among ['completion', 'infill', 'instruct', 'python'] + stream_output (bool): indicator for streaming output or not + **kwargs (dict): other arguments for initializing model's chat template + """ + log_level = os.environ.get('SERVICE_LOG_LEVEL', 'WARNING') + kwargs.update(capability=cap) + chatbot = ImChatbot(tritonserver_addr, + log_level=log_level, + display=stream_output, + **kwargs) + nth_round = 1 + while True: + prompt = input_prompt(chatbot.model_name) + if prompt == 'exit': + exit(0) + elif prompt == 'end': + chatbot.end(session_id) + else: + request_id = f'{session_id}-{nth_round}' + if stream_output: + for status, res, n_token in chatbot.stream_infer( + session_id, + prompt, + image_embs=None, + request_id=request_id, + request_output_len=512): + print(res) + else: + status, res, n_token = chatbot.infer(session_id, + prompt, + image_embs=None, + request_id=request_id, + request_output_len=512) + print(res) + nth_round += 1 + + +if __name__ == '__main__': + import fire + + fire.Fire(main) diff --git a/lmdeploy/serve/turbomind/chatbot.py b/lmdeploy/serve/turbomind/chatbot.py index 6419e53f2c..49c59d0ce2 100644 --- a/lmdeploy/serve/turbomind/chatbot.py +++ b/lmdeploy/serve/turbomind/chatbot.py @@ -29,7 +29,11 @@ class Session: sequence_length: int = 0 # the total generated token number in the session prompt: str = '' response: str = '' + response_ids: List[int] = None status: int = None # status of the session + image_embs: List[np.array] = None + image_offsets: List[int] = None + history_ids: List[int] = None class StatusCode(Enum): @@ -70,6 +74,8 @@ class Chatbot: profile_generation (bool): profile token generation or not """ + MODEL_REGISTRY = MODELS + def __init__(self, tritonserver_addr: str, model_name: str = '', @@ -81,20 +87,35 @@ def __init__(self, **model_kwargs): self.tritonserver_addr = tritonserver_addr self.model_name = model_name - if self.model_name == '': - self.model_name = self._get_model_name() - assert self.model_name in MODELS.module_dict.keys(), \ - f"'{self.model_name}' is not supported. " \ - f'The supported models are: {MODELS.module_dict.keys()}' - self.model = MODELS.get(self.model_name)(**model_kwargs) + self.ignore_eos = ignore_eos + self.log_level = log_level + self.display = display + self.profile_generation = profile_generation + self.profile_serving = profile_serving self._session = None + self._post_init(**model_kwargs) + + def _post_init(self, **model_kwargs): + self._init_prepost_processor() + self._init_cfg(**model_kwargs) + + def _init_prepost_processor(self): + tritonserver_addr = self.tritonserver_addr self.preprocess = Preprocessor(tritonserver_addr) self.postprocess = Postprocessor(tritonserver_addr) + + def _init_cfg(self, **model_kwargs): + if self.model_name == '': + self.model_name = self._get_model_name() + assert self.model_name in self.MODEL_REGISTRY.module_dict.keys(), \ + f"'{self.model_name}' is not supported. 
The supported models " \ + f'are: {self.MODEL_REGISTRY.module_dict.keys()}' + self.model = self.MODEL_REGISTRY.get(self.model_name)(**model_kwargs) self.bos_id = self._get_bos() self.eos_id = self._get_eos() stop_words = self._stop_words(self.model.stop_words) bad_words = None - if ignore_eos: + if self.ignore_eos: stop_words = None bad_words = np.array([[[self.eos_id], [1]]], dtype=np.int32) self.cfg = mmengine.Config( @@ -105,10 +126,6 @@ def __init__(self, repetition_penalty=self.model.repetition_penalty, stop_words=stop_words, bad_words=bad_words)) - self.log_level = log_level - self.display = display - self.profile_generation = profile_generation - self.profile_serving = profile_serving def stream_infer(self, session_id: int, @@ -147,14 +164,16 @@ def stream_infer(self, yield StatusCode.TRITON_SESSION_CLOSED, '', 0 return + self.cfg.update(**kwargs) self._session.status = 1 self._session.request_id = request_id self._session.response = '' - self.cfg.update(**kwargs) + self._session.response_ids = [] self._session.prompt = self._get_prompt(prompt, sequence_start) - for status, res, tokens in self._stream_infer(self._session, - self._session.prompt, + input_ids, _ = self.preprocess(self._session.prompt) + + for status, res, tokens in self._stream_infer(self._session, input_ids, request_output_len, sequence_start, sequence_end): @@ -200,8 +219,9 @@ def end(self, session_id: int, *args, **kwargs): return StatusCode.TRITON_SESSION_CLOSED self._session.status = 0 + input_ids = np.array([[1]], dtype=np.uint32) for status, _, _ in self._stream_infer(self._session, - prompt='', + input_ids=input_ids, request_output_len=0, sequence_start=False, sequence_end=True): @@ -241,8 +261,9 @@ def cancel(self, session_id: int, *args, **kwargs): prev_session = self._session status, res = None, None + input_ids = np.array([[]], dtype=np.uint32) for status, res, _ in self._stream_infer(self._session, - prompt='', + input_ids=input_ids, request_output_len=0, sequence_start=False, sequence_end=False, @@ -287,8 +308,10 @@ def resume(self, session_id: int, *args, **kwargs): self._session.status = 1 self._session.sequence_length = 0 histories = self._session.histories + input_ids, _ = self.preprocess(histories) + for status, _, _ in self._stream_infer(self._session, - prompt=histories, + input_ids=input_ids, request_output_len=0, sequence_start=True, sequence_end=False): @@ -339,11 +362,12 @@ def infer(self, self._session.status = 1 self._session.request_id = request_id self._session.response = '' + self._session.response_ids = [] self._session.prompt = self._get_prompt(prompt, sequence_start) + input_ids, _ = self.preprocess(self._session.prompt) status, res, tokens = None, '', 0 - for status, res, tokens in self._stream_infer(self._session, - self._session.prompt, + for status, res, tokens in self._stream_infer(self._session, input_ids, request_output_len, sequence_start, sequence_end): @@ -396,8 +420,8 @@ def _stop_words(self, stop_words: List[str]): if stop_words is None: return None assert isinstance(stop_words, List) and \ - all(isinstance(elem, str) for elem in stop_words), \ - f'stop_words must be a list but got {type(stop_words)}' + all(isinstance(elem, str) for elem in stop_words), \ + f'stop_words must be a list but got {type(stop_words)}' # each id in stop_words represents a stop word # refer to https://github.com/fauxpilot/fauxpilot/discussions/165 for # detailed explanation about turbomind's stop_words @@ -406,8 +430,8 @@ def _stop_words(self, stop_words: List[str]): for stop_word in stop_words ] 
assert isinstance(stop_words, List) and \ - all(isinstance(elem, int) for elem in stop_words), \ - 'invalid stop_words' + all(isinstance(elem, int) for elem in stop_words), \ + 'invalid stop_words' stop_word_offsets = range(1, len(stop_words) + 1) stop_words = np.array([[stop_words, stop_word_offsets]]).astype(np.int32) @@ -422,17 +446,18 @@ def _get_prompt(self, prompt: str, sequence_start: bool): def _stream_infer(self, session: Session, - prompt: str, + input_ids: np.array, request_output_len: int = 512, sequence_start: bool = True, sequence_end: bool = False, - cancel: bool = False): + cancel: bool = False, + **kwargs): """communicate with inference server to chat, or cancel a session, or end a session. Args: session (Session): an instance of a session - prompt (str): the concatenated prompt + input_ids (np.array): the input ids request_output_len (int): the max number of tokens to be generated sequence_start (bool): indicator for starting a sequence sequence_end (bool): indicator for ending a sequence @@ -448,9 +473,9 @@ def _stream_infer(self, f'end {sequence_end}, cancel {cancel}') assert request_output_len is None or \ - isinstance(request_output_len, int), \ - f'request_output_len is supposed to be None or int, ' \ - f'but got {type(request_output_len)}' + isinstance(request_output_len, int), \ + f'request_output_len is supposed to be None or int, ' \ + f'but got {type(request_output_len)}' if sequence_start: logger.info(f'session {session.session_id}, clear history since ' @@ -458,7 +483,7 @@ def _stream_infer(self, session.histories = '' session.sequence_length = 0 - input_ids, input_lengths = self.preprocess(prompt) + input_lengths = np.ones((1, 1), dtype=np.uint32) * input_ids.shape[-1] # got input_ids with default add_bos == True if not sequence_start and input_ids[0][0] == self.bos_id: input_ids = input_ids[:, 1:] @@ -470,7 +495,7 @@ def _stream_infer(self, input_tokens = input_lengths.squeeze() if self.profile_generation: yield StatusCode.TRITON_STREAM_ING, \ - 'ignore preprocessing during profiling generation', 0 + 'ignore preprocessing during profiling generation', 0 if request_output_len is None: request_output_len = max( 128, @@ -495,6 +520,7 @@ def _stream_infer(self, preseq_length = session.sequence_length session.response = '' + session.response_ids = [] session.status = StatusCode.TRITON_SESSION_READY que = queue.Queue() @@ -502,7 +528,8 @@ def _stream_infer(self, args=(self.tritonserver_addr, session, que, self.cfg, input_ids, input_lengths, request_output_len, sequence_start, - sequence_end, preseq_length, cancel)) + sequence_end, preseq_length, cancel), + kwargs=kwargs) producer.start() for status, res, n_token in self.stream_consumer( self.postprocess, que, session, input_tokens, preseq_length, @@ -517,10 +544,58 @@ def _stream_infer(self, f'{preseq_length}, cur seq_len {curseq_length}, ' f'diff {curseq_length - preseq_length}') - @staticmethod - def _stream_producer(tritonserver_addr, session, que, cfg, input_ids, + def _create_input(self, session, cfg, input_ids, input_lengths, + request_output_len, sequence_start, sequence_end, + preseq_length, cancel, **kwargs): + inputs = [ + prepare_tensor('input_ids', input_ids), + prepare_tensor('input_lengths', input_lengths), + prepare_tensor('request_output_len', request_output_len), + prepare_tensor('runtime_top_p', + cfg.top_p * np.ones((1, 1), dtype=np.float32)), + prepare_tensor('temperature', + cfg.temperature * np.ones( + (1, 1), dtype=np.float32)), + prepare_tensor( + 'repetition_penalty', + 
cfg.repetition_penalty * np.ones((1, 1), dtype=np.float32)), + prepare_tensor('step', + preseq_length * np.ones((1, 1), dtype=np.int32)) + ] + if cfg.top_k is not None: + inputs += prepare_tensor( + 'runtime_top_k', cfg.top_k * np.ones((1, 1), dtype=np.uint32)), + if cfg.stop_words is not None: + inputs += [prepare_tensor('stop_words_list', cfg.stop_words)] + if cfg.bad_words is not None: + inputs += [prepare_tensor('bad_words_list', cfg.bad_words)] + + inputs += [ + prepare_tensor( + 'session_len', + cfg.session_len * + np.ones([input_ids.shape[0], 1], dtype=np.uint32)), + prepare_tensor('START', (1 if sequence_start else 0) * np.ones( + (1, 1), dtype=np.int32)), + prepare_tensor('END', (1 if sequence_end else 0) * np.ones( + (1, 1), dtype=np.int32)), + prepare_tensor( + 'CORRID', session.session_id * np.ones( + (1, 1), dtype=np.uint64)), + prepare_tensor('STOP', (1 if cancel else 0) * np.ones( + (1, 1), dtype=np.int32)) + ] + if sequence_start: + random_seed = random.getrandbits(64) + inputs += [ + prepare_tensor('random_seed', + random_seed * np.ones((1, 1), dtype=np.uint64)) + ] + return inputs + + def _stream_producer(self, tritonserver_addr, session, que, cfg, input_ids, input_lengths, request_output_len, sequence_start, - sequence_end, preseq_length, cancel): + sequence_end, preseq_length, cancel, **kwargs): """Send a request to the triton inference server. Args: @@ -542,53 +617,10 @@ def _stream_producer(tritonserver_addr, session, que, cfg, input_ids, callback = partial(stream_callback, que) with grpcclient.InferenceServerClient(tritonserver_addr) as client: - inputs = [ - prepare_tensor('input_ids', input_ids), - prepare_tensor('input_lengths', input_lengths), - prepare_tensor('request_output_len', request_output_len), - prepare_tensor('runtime_top_p', - cfg.top_p * np.ones((1, 1), dtype=np.float32)), - prepare_tensor( - 'temperature', - cfg.temperature * np.ones((1, 1), dtype=np.float32)), - prepare_tensor( - 'repetition_penalty', - cfg.repetition_penalty * np.ones( - (1, 1), dtype=np.float32)), - prepare_tensor('step', - preseq_length * np.ones((1, 1), dtype=np.int32)) - ] - if cfg.top_k is not None: - inputs += prepare_tensor( - 'runtime_top_k', - cfg.top_k * np.ones((1, 1), dtype=np.uint32)), - if cfg.stop_words is not None: - inputs += [prepare_tensor('stop_words_list', cfg.stop_words)] - if cfg.bad_words is not None: - inputs += [prepare_tensor('bad_words_list', cfg.bad_words)] - - inputs += [ - prepare_tensor( - 'session_len', - cfg.session_len * - np.ones([input_ids.shape[0], 1], dtype=np.uint32)), - prepare_tensor('START', (1 if sequence_start else 0) * np.ones( - (1, 1), dtype=np.int32)), - prepare_tensor('END', (1 if sequence_end else 0) * np.ones( - (1, 1), dtype=np.int32)), - prepare_tensor( - 'CORRID', - session.session_id * np.ones((1, 1), dtype=np.uint64)), - prepare_tensor('STOP', (1 if cancel else 0) * np.ones( - (1, 1), dtype=np.int32)) - ] - if sequence_start: - random_seed = random.getrandbits(64) - inputs += [ - prepare_tensor( - 'random_seed', - random_seed * np.ones((1, 1), dtype=np.uint64)) - ] + inputs = self._create_input(session, cfg, input_ids, input_lengths, + request_output_len, sequence_start, + sequence_end, preseq_length, cancel, + **kwargs) client.start_stream(callback) client.async_stream_infer('turbomind', inputs, @@ -634,6 +666,7 @@ def stream_consumer(postprocess, res_queue, session, n_input_token, f'token {session.sequence_length}') session.sequence_length = preseq_length session.response = '' + session.response_ids = [] status = 
StatusCode.TRITON_SERVER_ERR res = f"{result['errcode']}, {result['errmsg']}" n_token = 0 @@ -675,6 +708,7 @@ def stream_consumer(postprocess, res_queue, session, n_input_token, if display: print(text, end='', flush=True) session.response += text + session.response_ids = output_ids.flatten().tolist() yield (StatusCode.TRITON_STREAM_ING, session.response, output_ids.shape[-1]) except Exception as e: diff --git a/lmdeploy/serve/turbomind/im_chatbot.py b/lmdeploy/serve/turbomind/im_chatbot.py new file mode 100644 index 0000000000..cd66ec7b29 --- /dev/null +++ b/lmdeploy/serve/turbomind/im_chatbot.py @@ -0,0 +1,259 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import List + +import numpy as np + +from lmdeploy.serve.turbomind.utils import (Postprocessor, XPreprocessor, + prepare_tensor) +from lmdeploy.xtokenizer import XTOKENIZERS + +from .chatbot import Chatbot, Session, StatusCode, filter_suffix, get_logger + + +class ImChatbot(Chatbot): + + MODEL_REGISTRY = XTOKENIZERS + + def _init_prepost_processor(self): + tritonserver_addr = self.tritonserver_addr + self.preprocess = XPreprocessor(tritonserver_addr) + self.postprocess = Postprocessor(tritonserver_addr) + + def _init_cfg(self, **model_kwargs): + super()._init_cfg(**model_kwargs) + self.img_start_id = self.model.img_start_id + self.img_end_id = self.model.img_end_id + + def stream_infer(self, + session_id: int, + prompt: str, + image_embs: List[np.array] = None, + request_id: str = '', + request_output_len: int = None, + sequence_start: bool = False, + sequence_end: bool = False, + *args, + **kwargs): + """Start a new round conversion of a session. + + Args: + session_id (int): the identical id of a session + prompt (str): user's prompt in this round conversation + image_embs (List[np.array]): image embedding features in this + round conversation + request_id (str): the identical id of this round conversation + request_output_len (int): the expected generated token numbers + sequence_start (bool): start flag of a session + sequence_end (bool): end flag of a session + Returns: + iterator: The generated content by chatbot + """ + + assert isinstance(session_id, int), \ + f'INT session id is required, but got {type(session_id)}' + + logger = get_logger(log_level=self.log_level) + logger.info(f'session {session_id}, request_id {request_id}, ' + f'request_output_len {request_output_len}') + + if self._session is None: + sequence_start = True + self._session = Session(session_id=session_id) + elif self._session.status == 0: + logger.error(f'session {session_id} has been ended. 
Please set ' + f'`sequence_start` be True if you want to restart it') + yield StatusCode.TRITON_SESSION_CLOSED, '', 0 + return + + self.cfg.update(**kwargs) + self._session.status = 1 + self._session.request_id = request_id + self._session.response = '' + self._session.response_ids = [] + preseq_length = self._session.sequence_length + + if self._session.history_ids is None or sequence_start: + self._session.history_ids = [] + self._session.image_embs = [] + self._session.image_offsets = [] + + self._session.prompt = prompt + input_ids, _, image_offsets = self.preprocess( + self._session.prompt, + sequence_start=sequence_start, + num_image=len(image_embs) if image_embs is not None else 0) + + for status, res, tokens in self._stream_infer( + self._session, + input_ids, + request_output_len, + sequence_start, + sequence_end, + image_embs=image_embs, + image_offsets=image_offsets): + if status == StatusCode.TRITON_STREAM_END: # remove stop_words + res = filter_suffix(res, self.model.stop_words) + if status.value < 0: + break + else: + yield status, res, tokens + if status.value == 0: + self._session.history_ids.extend(input_ids.flatten().tolist() + + self._session.response_ids) + if image_embs is not None: + self._session.image_embs.extend(image_embs) + self._session.image_offsets.extend( + (image_offsets + preseq_length).tolist()) + else: + yield status, res, tokens + + def _create_input(self, + session, + cfg, + input_ids, + input_lengths, + request_output_len, + sequence_start, + sequence_end, + preseq_length, + cancel, + image_embs=None, + image_offsets=None, + **kwargs): + inputs = super()._create_input(session, cfg, input_ids, input_lengths, + request_output_len, sequence_start, + sequence_end, preseq_length, cancel) + if image_embs is not None: + image_embs = [x.squeeze()[None] for x in image_embs] + image_embs = np.concatenate(image_embs, axis=0)[None] + image_embs = image_embs.astype(np.float16) + inputs += [ + prepare_tensor('image_embs', image_embs), + prepare_tensor('image_offsets', image_offsets), + ] + return inputs + + def resume(self, session_id: int, *args, **kwargs): + """Resume a session by sending the history conversations to triton + inference server. After resuming, users can continue chatting with + chatbot. + + Args: + session_id (int): the identical id of a session + + Returns: + int: 0: success, -1: session not found + """ + assert isinstance(session_id, int), \ + f'INT session id is required, but got {type(session_id)}' + + logger = get_logger(log_level=self.log_level) + logger.info(f'resume session: {session_id}') + + if self._session is None: + logger.error( + f"session {session_id} doesn't exist. 
It cannot be recovered") + return StatusCode.TRITON_SESSION_INVALID_ARG + if self._session.session_id != session_id: + logger.error( + f'you cannot resume session {session_id}, because this ' + f'session is {self._session.session_id}') + return StatusCode.TRITON_SESSION_INVALID_ARG + + self._session.status = 1 + self._session.sequence_length = 0 + input_ids = self._session.history_ids + image_embs = self._session.image_embs + image_offsets = self._session.image_offsets + + for status, _, _ in self._stream_infer(self._session, + input_ids=input_ids, + request_output_len=0, + sequence_start=True, + sequence_end=False, + image_embs=image_embs, + image_offsets=image_offsets): + if status.value < 0: + break + + return status + + def infer(self, + session_id: int, + prompt: str, + image_embs: List[np.array] = None, + request_id: str = '', + request_output_len: int = None, + sequence_start: bool = False, + sequence_end: bool = False, + *args, + **kwargs): + """Start a new round conversion of a session. Return the chat + completions in non-stream mode. + + Args: + session_id (int): the identical id of a session + prompt (str): user's prompt in this round conversation + image_embs (List[np.array]): image embedding features in this + round conversation + request_id (str): the identical id of this round conversation + request_output_len (int): the expected generated token numbers + sequence_start (bool): start flag of a session + sequence_end (bool): end flag of a session + Returns: + tuple(Status, str, int): status, text/chat completion, + generated token number + """ + assert isinstance(session_id, int), \ + f'INT session id is required, but got {type(session_id)}' + + logger = get_logger(log_level=self.log_level) + logger.info(f'session {session_id}, request_id {request_id}, ' + f'request_output_len {request_output_len}') + + if self._session is None: + sequence_start = True + self._session = Session(session_id=session_id) + elif self._session.status == 0: + logger.error(f'session {session_id} has been ended. 
Please set ' + f'`sequence_start` be True if you want to restart it') + return StatusCode.TRITON_SESSION_CLOSED, '', 0 + + self.cfg.update(**kwargs) + self._session.status = 1 + self._session.request_id = request_id + self._session.response = '' + self._session.response_ids = [] + preseq_length = self._session.sequence_length + + if self._session.history_ids is None or sequence_start: + self._session.history_ids = [] + self._session.image_embs = [] + self._session.image_offsets = [] + self._session.prompt = prompt + input_ids, _, image_offsets = self.preprocess( + self._session.prompt, + sequence_start=sequence_start, + num_image=len(image_embs) if image_embs is not None else 0) + status, res, tokens = None, '', 0 + for status, res, tokens in self._stream_infer( + self._session, + input_ids, + request_output_len, + sequence_start, + sequence_end, + image_embs=image_embs, + image_offsets=image_offsets): + if status.value < 0: + break + if status == StatusCode.TRITON_STREAM_END: # remove stop_words + res = filter_suffix(res, self.model.stop_words) + if status.value == 0: + self._session.history_ids.extend(input_ids.flatten().tolist() + + self._session.response_ids) + if image_embs is not None: + self._session.image_embs.extend(image_embs) + self._session.image_offsets.extend( + (image_offsets + preseq_length).tolist()) + + return status, res, tokens diff --git a/lmdeploy/serve/turbomind/triton_models/interactive/config.pbtxt b/lmdeploy/serve/turbomind/triton_models/interactive/config.pbtxt index 003881ce43..e3e62c655e 100644 --- a/lmdeploy/serve/turbomind/triton_models/interactive/config.pbtxt +++ b/lmdeploy/serve/turbomind/triton_models/interactive/config.pbtxt @@ -59,6 +59,18 @@ input [ data_type: TYPE_UINT32 dims: [ -1 ] }, + { + name: "image_embs" + data_type: TYPE_FP16 + dims: [ -1, -1, -1 ] + optional: true + }, + { + name: "image_offsets" + data_type: TYPE_INT32 + dims: [ -1 ] + optional: true + }, { name: "step" data_type: TYPE_INT32 diff --git a/lmdeploy/serve/turbomind/triton_models/xpreprocessing/1/model.py b/lmdeploy/serve/turbomind/triton_models/xpreprocessing/1/model.py new file mode 100644 index 0000000000..4a269ca524 --- /dev/null +++ b/lmdeploy/serve/turbomind/triton_models/xpreprocessing/1/model.py @@ -0,0 +1,179 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import json +import os.path as osp +from pathlib import Path + +import numpy as np +import torch +import triton_python_backend_utils as pb_utils +from torch.nn.utils.rnn import pad_sequence + +# This tokenizer is `lmdeploy/turbomind/tokenizer.py`. When an LLM is served +# by triton inference server, it has to be converted first by running +# `python lmdeploy/serve/turbomind/deploy.py`. Then +# `lmdeploy/turbomind/tokenizer.py` will be copied to `tokenizer/tokenizer.py` +from .tokenizer.tokenizer import Tokenizer + + +class TritonPythonModel: + """Your Python model must use the same class name. + + Every Python model that is created must have "TritonPythonModel" as the + class name. + """ + + def initialize(self, args): + """`initialize` is called only once when the model is being loaded. + Implementing `initialize` function is optional. This function allows + the model to initialize any state associated with this model. + Parameters + ---------- + args : dict + Both keys and values are strings. 
The dictionary keys and values are: + * model_config: A JSON string containing the model configuration + * model_instance_kind: A string containing model instance kind + * model_instance_device_id: A string containing model instance device + ID + * model_repository: Model repository path + * model_version: Model version + * model_name: Model name + """ + # Parse model configs + self.model_config = model_config = json.loads(args['model_config']) + assert model_config['max_batch_size'] == 1 + self.model_name = self.model_config['parameters']['model_name'][ + 'string_value'] + + # Parse model output configs and convert Triton types to numpy types + input_names = ['INPUT_ID', 'REQUEST_INPUT_LEN', 'OFFSET'] + for input_name in input_names: + setattr( + self, + input_name.lower() + '_dtype', + pb_utils.triton_string_to_numpy( + pb_utils.get_output_config_by_name( + model_config, input_name)['data_type'])) + + from lmdeploy.xtokenizer import XTOKENIZERS + cur_folder = Path(__file__).parent + self.tokenizer = Tokenizer( + osp.join( + cur_folder, self.model_config['parameters']['tokenizer_path'] + ['string_value'])) + self.xtok = XTOKENIZERS.get(self.model_name)(self.tokenizer) + self.start_id = self.tokenizer.bos_token_id + self.end_id = self.tokenizer.eos_token_id + + def execute(self, requests): + """`execute` must be implemented in every Python model. `execute` + function receives a list of pb_utils.InferenceRequest as the only + argument. This function is called when an inference is requested + for this model. Depending on the batching configuration (e.g. Dynamic + Batching) used, `requests` may contain multiple requests. Every + Python model, must create one pb_utils.InferenceResponse for every + pb_utils.InferenceRequest in `requests`. If there is an error, you can + set the error argument when creating a pb_utils.InferenceResponse. + Parameters + ---------- + requests : list + A list of pb_utils.InferenceRequest + Returns + ------- + list + A list of pb_utils.InferenceResponse. The length of this list must + be the same as `requests` + """ + + responses = [] + + # Every Python backend must iterate over everyone of the requests + # and create a pb_utils.InferenceResponse for each of them. + assert len(requests) == 1 + for idx, request in enumerate(requests): + # Get input tensors + query = pb_utils.get_input_tensor_by_name(request, + 'QUERY').as_numpy() + kwargs = pb_utils.get_input_tensor_by_name(request, + 'KWARGS').as_numpy() + kwargs = json.loads(kwargs[0][0]) + + if len(kwargs) == 0: + # Preprocessing input data. + input_id, request_input_len = self._create_request(query) + offset = [] + else: + input_id, offset = self._query2ids(query[0][0].decode(), + **kwargs) + request_input_len = np.ones( + (1, 1), + dtype=self.request_input_len_dtype) * input_id.shape[-1] + + # Create output tensors. You need pb_utils.Tensor + # objects to create pb_utils.InferenceResponse. + input_id_tensor = pb_utils.Tensor( + 'INPUT_ID', + np.array(input_id).astype(self.input_id_dtype)) + request_input_len_tensor = pb_utils.Tensor( + 'REQUEST_INPUT_LEN', + np.array(request_input_len).astype( + self.request_input_len_dtype)) + offset_tensor = pb_utils.Tensor( + 'OFFSET', + np.array(offset).astype(self.offset_dtype)) + + # Create InferenceResponse. You can set an error here in case + # there was a problem with handling this inference request. 
+ # Below is an example of how you can set errors in inference + # response: + # + # pb_utils.InferenceResponse( + # output_tensors=..., TritonError("An error occurred")) + inference_response = pb_utils.InferenceResponse(output_tensors=[ + input_id_tensor, request_input_len_tensor, offset_tensor + ]) + responses.append(inference_response) + + # You should return a list of pb_utils.InferenceResponse. Length + # of this list must match the length of `requests` list. + return responses + + def finalize(self): + """`finalize` is called only once when the model is being unloaded. + + Implementing `finalize` function is optional. This function allows the + model to perform any necessary clean ups before exit. + """ + print('Cleaning up...') + + def _query2ids(self, query, num_image, sequence_start, **kwargs): + input_id, offset = self.xtok.query2ids((query, num_image), + sequence_start=sequence_start) + input_id = np.array(input_id)[None] + offset = np.array(offset, dtype=np.int32)[None] + return input_id, offset + + def _create_request(self, query): + """Tokenize prompts and return the token ids and their length. + + Args: + query (List[str]): a list of prompt + Returns: + tuple: token ids and their length + """ + start_ids = [] + for s in query: + _s = s[0].decode() + if _s == '': + start_id = [self.start_id + ] if self.start_id is not None else [-1] + elif _s == '': + start_id = [self.end_id] if self.end_id is not None else [-1] + else: + start_id = self.tokenizer.encode(_s) + start_ids.append(torch.IntTensor(start_id)) + + start_lengths = torch.IntTensor([[len(ids)] for ids in start_ids]) + start_ids = pad_sequence(start_ids, + batch_first=True, + padding_value=self.end_id) + return start_ids, start_lengths diff --git a/lmdeploy/serve/turbomind/triton_models/xpreprocessing/config.pbtxt b/lmdeploy/serve/turbomind/triton_models/xpreprocessing/config.pbtxt new file mode 100644 index 0000000000..b8fd049e4d --- /dev/null +++ b/lmdeploy/serve/turbomind/triton_models/xpreprocessing/config.pbtxt @@ -0,0 +1,55 @@ +name: "xpreprocessing" +backend: "python" +max_batch_size: 1 + +input [ + { + name: "QUERY" + data_type: TYPE_STRING + dims: [ -1 ] + }, + { + name: "KWARGS" + data_type: TYPE_STRING + dims: [ -1 ] + } +] +output [ + { + name: "INPUT_ID" + data_type: TYPE_UINT32 + dims: [ -1 ] + }, + { + name: "REQUEST_INPUT_LEN" + data_type: TYPE_UINT32 + dims: [ 1 ] + }, + { + name: "OFFSET" + data_type: TYPE_INT32 + dims: [ 1 ] + } +] + +instance_group [ + { + count: 4 + kind: KIND_CPU + } +] + +parameters { + key: "tokenizer_path" + value: { + string_value: "tokenizer/tokenizer.model" + } +} + +# information +# parameters { +# key: "model_name" +# value: { +# string_value: "qwen-vl-chat" +# } +# } diff --git a/lmdeploy/serve/turbomind/utils.py b/lmdeploy/serve/turbomind/utils.py index 802f6abaa4..005990811c 100644 --- a/lmdeploy/serve/turbomind/utils.py +++ b/lmdeploy/serve/turbomind/utils.py @@ -58,6 +58,61 @@ def infer(self, prompts: Union[str, List[str]]) -> tuple: return output0, output1 +class XPreprocessor: + """Tokenize raw prompts. + + Args: + tritonserver_addr (str): the communication address of the inference + server + """ + + def __init__(self, tritonserver_addr: str): + self.tritonserver_addr = tritonserver_addr + self.model_name = 'xpreprocessing' + + def __call__(self, *args, **kwargs): + return self.infer(*args, **kwargs) + + def infer(self, prompts: Union[str, List[str]], **kwargs) -> tuple: + """Tokenize the input prompts. 
+ + Args: + prompts(str | List[str]): user's prompt, or a batch prompts + kwargs(dict): kwargs + + Returns: + Tuple(numpy.ndarray, numpy.ndarray, numpy.ndarray): prompt's token + ids, ids' length and requested output length + """ + import json + if isinstance(prompts, str): + input0 = [[prompts]] + elif isinstance(prompts, List): + input0 = [[prompt] for prompt in prompts] + else: + assert 0, f'str or List[str] prompts are expected but got ' \ + f'{type(prompts)}' + + input0_data = np.array(input0).astype(object) + input1_data = np.array([[json.dumps(kwargs)] * len(input0) + ]).astype(object) + inputs = [ + prepare_tensor('QUERY', input0_data), + prepare_tensor('KWARGS', input1_data) + ] + + with grpcclient.InferenceServerClient(self.tritonserver_addr) as \ + client: + result = client.infer(self.model_name, inputs) + output0 = result.as_numpy('INPUT_ID') + output1 = result.as_numpy('REQUEST_INPUT_LEN') + output2 = result.as_numpy('OFFSET') + if len(kwargs): + return output0, output1, output2 + else: + return output0, output1 + + class Postprocessor: """De-tokenize prompts. diff --git a/lmdeploy/turbomind/deploy/target_model/base.py b/lmdeploy/turbomind/deploy/target_model/base.py index 92e6232301..0b9cceff1d 100644 --- a/lmdeploy/turbomind/deploy/target_model/base.py +++ b/lmdeploy/turbomind/deploy/target_model/base.py @@ -58,6 +58,7 @@ class TurbomindModelConfig: max_position_embeddings: int = 0 rope_scaling_factor: float = 0.0 use_logn_attn: int = 0 + image_dim: int = 0 @classmethod def from_dict(cls, env, allow_none=False): diff --git a/lmdeploy/turbomind/turbomind.py b/lmdeploy/turbomind/turbomind.py index 74fcaf6355..f626f2dda3 100644 --- a/lmdeploy/turbomind/turbomind.py +++ b/lmdeploy/turbomind/turbomind.py @@ -553,8 +553,9 @@ def _broadcast_np(data, dtype, shape=(batch_size, )): if isinstance(image_offsets[0], int): image_offsets = [image_offsets] image_embs = [image_embs] - image_embs = [[torch.from_numpy(x).unsqueeze(0) for x in y] - for y in image_embs] + image_embs = [[ + torch.from_numpy(x).squeeze().unsqueeze(0) for x in y + ] for y in image_embs] image_embs = [torch.cat(x) for x in image_embs] image_embs = pad_sequence(image_embs, batch_first=True) image_offsets = [torch.IntTensor(x) for x in image_offsets] diff --git a/lmdeploy/xtokenizer.py b/lmdeploy/xtokenizer.py new file mode 100644 index 0000000000..7a4f5e7a07 --- /dev/null +++ b/lmdeploy/xtokenizer.py @@ -0,0 +1,121 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import copy +from abc import ABC, abstractmethod + +import numpy as np +from mmengine import Registry + +from .model import MODELS, SamplingParam + +XTOKENIZERS = Registry('xtokenizer', locations=['lmdeploy.xtokenizer']) + + +class BaseModel(ABC): + """Base model.""" + + def __init__(self, + session_len=2048, + top_p=0.8, + top_k=None, + temperature=0.8, + repetition_penalty=1.0, + stop_words=None, + capability='chat', + **kwargs): + self.session_len = session_len + self.top_p = top_p + self.top_k = top_k + self.temperature = temperature + self.repetition_penalty = repetition_penalty + self.stop_words = stop_words + self.capability = capability + + @abstractmethod + def query2ids(self, query, sequence_start=True, **kwargs): + """Return input ids and padding offsets.""" + + @abstractmethod + def messages2ids(self, messages, sequence_start=True, **kwargs): + """Return input ids and padding offsets. 
+ + user content should be str or (str, int) or [str, int] + """ + + @property + def sampling_param(self): + return SamplingParam(top_p=self.top_p, + top_k=self.top_k, + temperature=self.temperature, + repetition_penalty=self.repetition_penalty) + + +@XTOKENIZERS.register_module(name='qwen-vl') +@XTOKENIZERS.register_module(name='qwen-vl-chat') +class QwenVL(BaseModel): + """Qwen VL tokenizer.""" + + def __init__(self, + tokenizer=None, + session_len=8192, + top_p=0.3, + top_k=0, + temperature=1.0, + im_start='<|im_start|>', + im_end='<|im_end|>', + system='You are a helpful assistant.', + stop_words=['<|im_end|>'], + img_start_id=151857, + img_end_id=151858, + **kwargs): + super().__init__(**kwargs) + self.session_len = session_len + self.top_p = top_p + self.top_k = top_k + self.temperature = temperature + self.im_start = im_start + self.im_end = im_end + self.system = system + self.stop_words = stop_words + self.img_start_id = img_start_id + self.img_end_id = img_end_id + self.model = MODELS.get('qwen-7b')(im_start=im_start, + im_end=im_end, + system=system) + self.tokenizer = tokenizer + + def _construct_query(self, query): + if isinstance(query, str): + return query + query, nimg = query + text = '' + for i in range(nimg): + text += f'Picture {i + 1}:placeholder\n' + text += query + return text + + def _get_image_offsets(self, input_ids): + input_ids = np.array(input_ids) + offsets = np.where(input_ids == self.img_start_id)[0] + 1 + return offsets.tolist() + + def query2ids(self, query, sequence_start=True, **kwargs): + text = self._construct_query(query) + decorated_text = self.model.decorate_prompt( + text, sequence_start=sequence_start) + input_ids = self.tokenizer.encode(decorated_text) + offsets = self._get_image_offsets(input_ids) + return input_ids, offsets + + def messages2ids(self, messages, sequence_start=True, **kwargs): + if isinstance(messages, str) or isinstance(messages, (tuple, list)): + return self.query2ids(messages, sequence_start) + messages_cp = copy.deepcopy(messages) + for message in messages_cp: + msg_role = message['role'] + if msg_role == 'user': + message['content'] = self._construct_query(message['content']) + decorated_text = self.model.messages2prompt( + messages_cp, sequence_start=sequence_start) + input_ids = self.tokenizer.encode(decorated_text) + offsets = self._get_image_offsets(input_ids) + return input_ids, offsets
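
Usage note (not part of the patch): a rough sketch of how a client might drive the new ImChatbot with image features, based on the signatures added above. The server address and the embedding shape are illustrative assumptions; the image embeddings themselves are expected to come from the model's own visual encoder, which is outside the scope of this change.

    import numpy as np

    from lmdeploy.serve.turbomind.im_chatbot import ImChatbot

    # Hypothetical address of a running triton inference server that serves the
    # converted model together with the new `xpreprocessing` module.
    chatbot = ImChatbot('0.0.0.0:33337', log_level='WARNING', display=True)

    # One embedding per image; the (256, 4096) shape is only an example and
    # depends on the visual encoder actually used.
    image_embs = [np.random.rand(256, 4096).astype(np.float16)]

    # The prompt stays plain text: `xpreprocessing` inserts the image
    # placeholder tokens according to len(image_embs) and returns their offsets.
    for status, res, n_token in chatbot.stream_infer(session_id=1,
                                                     prompt='Describe the picture.',
                                                     image_embs=image_embs,
                                                     request_id='1-1',
                                                     request_output_len=512):
        print(res)

The session additionally caches history_ids, image_embs and image_offsets on the client side, so resume() can replay them to the server after its sequence cache has been dropped.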