update perf metrics & adaptive tokens per tick
lzhangzz committed Dec 11, 2024
1 parent 9c56be8 commit 2cf49bd
Showing 11 changed files with 255 additions and 144 deletions.
231 changes: 124 additions & 107 deletions benchmark/profile_throughput.py
@@ -2,6 +2,7 @@
import argparse
import asyncio
import csv
import itertools
import json
import os
import random
@@ -86,20 +87,23 @@ def __init__(self, model_path: str,
self.csv = csv
self.pbar = None

async def _inference(self, req_queue: Queue, res_queue: Queue,
session_id: int, temperature: float, top_p: float,
top_k: int, stream_output: bool):
async def _inference(self, req_queue: Queue, session_id: int,
temperature: float, top_p: float, top_k: int,
stream_output: bool, pretokenize: bool,
skip_detokenize: bool):
model_inst = self.tm_model.create_instance()
stats = []
# get each generated token's latency
per_token_latency_stats = []
counters = []
for prompt, input_seqlen, output_seqlen in iter(
req_queue.get_nowait, [None, None, None]):
_per_token_latency_stats = [0] * (output_seqlen + 1)
prev = time.perf_counter()
n_prev_token = 0

input_ids = self.tokenizer(prompt).input_ids
ts = [time.perf_counter()]
ns = [0]

if pretokenize:
input_ids = prompt
else:
input_ids = self.tokenizer(prompt).input_ids

state = DetokenizeState(len(input_ids))

async for outputs in model_inst.async_stream_infer(
@@ -114,43 +118,37 @@ async def _inference(self, req_queue: Queue, res_queue: Queue,
sequence_end=True,
stream_output=stream_output):
res, n_token = input_ids + outputs.token_ids, outputs.num_token
_, state = self.tokenizer.detokenize_incrementally(res, state)
now = time.perf_counter()
if n_prev_token != n_token:
_per_token_latency_stats[n_prev_token] = np.round(
now - prev, 3)
n_prev_token = n_token
prev = now
if not skip_detokenize:
_, state = self.tokenizer.detokenize_incrementally(
res, state)
# The following does not help
# await asyncio.sleep(0)
# _, state = await loop.run_in_executor(None, self.tokenizer.detokenize_incrementally, res, state)

ts.append(time.perf_counter())
ns.append(n_token)

# for pytorch engine to restart a session
if isinstance(model_inst, EngineInstance):
await model_inst.async_end(session_id)
assert output_seqlen <= n_token <= output_seqlen + 1, \
f'Error. session_id({session_id}) request {output_seqlen} ' \
f'tokens, but generate {n_token} tokens.\n' \
f'prompt: {prompt}'

first_token_latency = _per_token_latency_stats[0]
completion_tokens = n_token
total_tokens = n_token + input_seqlen
stats.append([
first_token_latency, completion_tokens, output_seqlen,
total_tokens
])
# skip the first token latency
per_token_latency_stats.append(_per_token_latency_stats[1:])

counters.append((ts, ns, input_seqlen))
self.pbar.update(1)
res_queue.put_nowait((session_id, stats, per_token_latency_stats))

return counters

def process_request(self, requests, concurrency, temperature, top_p, top_k,
stream_output):
res_queue = Queue()
stream_output, pretokenize, skip_detokenize):
req_queue = Queue()

self.pbar = tqdm(total=len(requests))

# feed request to q
for req in requests:
req_queue.put(req)
if pretokenize:
req_queue.put((self.tokenizer.encode(req[0]), *req[1:]))
else:
req_queue.put(req)
for i in range(concurrency):
req_queue.put([None, None, None])

@@ -162,87 +160,95 @@ def process_request(self, requests, concurrency, temperature, top_p, top_k,
# start threads
tasks = []
for i in range(concurrency):
task = self._inference(req_queue, res_queue, i, temperature, top_p,
top_k, stream_output)
task = self._inference(req_queue, i, temperature, top_p, top_k,
stream_output, pretokenize, skip_detokenize)
tasks.append(task)

async def _gather_tasks(tasks):
return await asyncio.gather(*tasks)

event_loop.run_until_complete(_gather_tasks(tasks))
counters = asyncio.run(_gather_tasks(tasks))

self.pbar.close()

elapsed_time = time.time() - start

stats = []
per_token_latency_stats = []
while not res_queue.empty():
session_id, _stats, _per_token_latency_stats = res_queue.get()
stats.append(np.array(_stats))
per_token_latency_stats += [
item for sublist in _per_token_latency_stats
for item in sublist
]
stats = np.concatenate(stats).reshape(-1, 4)

first_token_latency_min = np.min(stats[:, 0], axis=0)
first_token_latency_max = np.max(stats[:, 0], axis=0)
first_token_latency_ave = np.mean(stats[:, 0], axis=0)
completion_tokens = np.sum(stats[:, 1], axis=0)
total_tokens = np.sum(stats[:, 3], axis=0)
prompt_tokens = total_tokens - completion_tokens
completion_token_throughput = completion_tokens / elapsed_time
total_token_throughput = total_tokens / elapsed_time
ttfts: List[float] = []
tpots: List[float] = []
e2es: List[float] = []
itls: List[float] = []
tpts: List[int] = []

total_output = 0
total_input = 0

for ts, ns, input_len in itertools.chain.from_iterable(counters):
# print (ts)
# print (ns)
# assert 0
total_output += ns[-1]
total_input += input_len
e2es.append(ts[-1] - ts[0])
ttfts.append(ts[1] - ts[0])
if ns[-1] > ns[1]:
tpots.append((ts[-1] - ts[1]) / (ns[-1] - ns[1]))
else: # no-stream-output
tpots.append((ts[-1] - ts[0]) / (ns[-1] - ns[0]))
t_dif = np.subtract(ts[1:], ts[:-1])
n_dif = np.subtract(ns[1:], ns[:-1])
itls.extend(t_dif[1:])
tpts.extend(n_dif[1:])

output_throughput = total_output / elapsed_time
input_throughput = total_input / elapsed_time

qs = (50, 75, 90, 99)

tpot_ms_mean = np.mean(tpots)
tpot_ms_stat = tuple(np.percentile(tpots, qs))
e2e_mean = np.mean(e2es)
e2e_stat = tuple(np.percentile(e2es, qs))

if stream_output:
ttft_ms_mean = np.mean(ttfts)
ttft_ms_stat = tuple(np.percentile(ttfts, qs))
itls_ms_mean = np.mean(itls)
itls_ms_stat = tuple(np.percentile(itls, qs))
tpts_ms_mean = np.mean(tpts)
tpts_ms_stat = tuple(np.percentile(tpts, qs).astype(int))

rps = len(requests) / elapsed_time
rpm = rps * 60

per_token_latency_stats.sort()
percentiles = [
np.round(
per_token_latency_stats[int(percent *
len(per_token_latency_stats))], 3)
for percent in [0.5, 0.75, 0.95, 0.99]
]

print(f'\n{"-" * 50}\nconcurrency: {concurrency}\n'
f'elapsed_time: {elapsed_time:.3f}s\n')

def tab_row(name, *items):

def fmt(x):
return '{:>10.3f}'.format(x) if isinstance(
x, float) else '{:>10}'.format(x)

print('{:<35}{}'.format(name, ''.join([fmt(x) for x in items])))

print('\n{s:{c}^{n}}'.format(s=' Profile Throughput ', n=85, c='='))
tab_row('Benchmark duration', elapsed_time)
tab_row('Total requests', len(requests))
tab_row('Concurrency', concurrency)
tab_row('Stream output', str(stream_output).lower())
tab_row('Pre-tokenization', str(pretokenize).lower())
tab_row('Skip detokenization', str(skip_detokenize).lower())
tab_row('Total input tokens', total_input)
tab_row('Total generated tokens', total_output)
tab_row('Input token throughput (tok/s)', input_throughput)
tab_row('Output token throughput (tok/s)', output_throughput)
tab_row('Request throughput (req/s)', rps)
print('-' * 85)
tab_row('', 'mean', *(f'P{q}' for q in qs))
tab_row('End-to-end Latency', e2e_mean, *e2e_stat)
if stream_output:
tab_row('Time to First Token (TTFT)', ttft_ms_mean, *ttft_ms_stat)
tab_row('Time per Output Token (TPOT)', tpot_ms_mean, *tpot_ms_stat)
if stream_output:
print(f'first token latency(s)(min, max, ave): '
f'{first_token_latency_min:.3f}, '
f'{first_token_latency_max:.3f}, '
f'{first_token_latency_ave:.3f}')
print(f'per-token latency(s) percentile(50, 75, 95, 99): '
f'{percentiles}\n')
print(
f'number of prompt tokens: {prompt_tokens:.0f}\n'
f'number of completion tokens: {completion_tokens:.0f}\n'
f'token throughput (completion token): {completion_token_throughput:.3f} token/s\n' # noqa
f'token throughput (prompt + completion token): {total_token_throughput:.3f} token/s\n' # noqa
f'RPS (request per second): {rps:.3f} req/s\n'
f'RPM (request per minute): {rpm:.3f} req/min\n'
f'{"-" * 50}\n')

if self.csv:
with open(self.csv, 'w') as csvfile:
writer = csv.writer(csvfile)
writer.writerow([
'batch', 'num_promts', 'RPS', 'RPM', 'FTL(ave)(s)',
'FTL(min)(s)', 'FTL(max)(s)', '50%(s)', '75%(s)', '95%(s)',
'99%(s)', 'throughput(out tok/s)',
'throughput(total tok/s)'
])
writer.writerow([
concurrency,
len(requests), f'{rps:.3f}', f'{rpm:.3f}',
f'{first_token_latency_ave:.3f}' if stream_output else '-',
f'{first_token_latency_min:.3f}' if stream_output else '-',
f'{first_token_latency_max:.3f}' if stream_output else '-',
f'{percentiles[0]:.3f}' if stream_output else '-',
f'{percentiles[1]:.3f}' if stream_output else '-',
f'{percentiles[2]:.3f}' if stream_output else '-',
f'{percentiles[3]:.3f}' if stream_output else '-',
f'{completion_token_throughput:.3f}',
f'{total_token_throughput:.3f}'
])
tab_row('Inter-token Latency (ITL)', itls_ms_mean, *itls_ms_stat)
tab_row('Tokens per Tick', tpts_ms_mean, *tpts_ms_stat)
print('=' * 85)


def parse_args():
@@ -266,6 +272,15 @@ def parse_args():
type=int,
help='Number of prompts to process',
default=5000)
parser.add_argument('--no-stream-output',
action='store_true',
help='Disable stream output')
parser.add_argument('--pre-tokenize',
action='store_true',
help='Pre-tokenize input prompts before starting')
parser.add_argument('--skip-detokenize',
action='store_true',
help='Skip detokenizing output tokens')
parser.add_argument('--csv',
type=str,
help='Where to save the result.',
@@ -350,7 +365,9 @@ def main():
top_p=args.top_p,
top_k=args.top_k,
concurrency=args.concurrency,
stream_output=True)
stream_output=not args.no_stream_output,
pretokenize=args.pre_tokenize,
skip_detokenize=args.skip_detokenize)


if __name__ == '__main__':
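For reference, here is a minimal, self-contained sketch (not part of this commit) of how the per-session `(ts, ns)` counters collected in `_inference` reduce to TTFT, TPOT, ITL, and tokens-per-tick in `process_request`. The sample timestamps and token counts below are hypothetical.

```python
import numpy as np

# Hypothetical per-session counters as collected by _inference():
# ts[i] is the wall-clock time at the i-th stream callback, ns[i] the cumulative
# number of generated tokens at that point (ts[0]/ns[0] mark the request start).
ts = [0.00, 0.12, 0.15, 0.19, 0.22]   # seconds
ns = [0, 1, 2, 3, 4]                  # cumulative generated tokens

ttft = ts[1] - ts[0]                          # time to first token
tpot = (ts[-1] - ts[1]) / (ns[-1] - ns[1])    # time per output token after the first
itl = np.subtract(ts[1:], ts[:-1])[1:]        # inter-token latencies (drop the first gap)
tpt = np.subtract(ns[1:], ns[:-1])[1:]        # tokens emitted per tick

print(f'TTFT={ttft:.3f}s  TPOT={tpot:.3f}s')
print('ITL:', np.round(itl, 3).tolist(), 'TPT:', tpt.tolist())
```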
43 changes: 38 additions & 5 deletions lmdeploy/turbomind/turbomind.py
@@ -132,6 +132,9 @@ def __init__(self,
self.session_len = self.config.session_len
self.eos_id = self.tokenizer.eos_token_id

self.pending_num = 0
self.pending_cond = asyncio.Condition()

def _create_weight(self, model_comm):
"""Allocate weight buffer, load params if from_workspace."""

@@ -511,11 +514,18 @@ async def async_signal(self, state):
async with self.cond:
self.flag, self.state = 1, state
self.cond.notify()
# self.flag, self.state = 1, state
# self.state_ready.set()

def async_signal_cb(self, state):
coro = self.async_signal(state)
asyncio.run_coroutine_threadsafe(coro, self.event_loop)

def add_pending(self, n: int = 1):
# self.tm_model.pending_event.clear()
# self.tm_model.pending_num += n
self.tm_model.pending_num += n

async def async_stream_infer(self,
session_id,
input_ids,
@@ -548,6 +558,7 @@ async def async_stream_infer(self,

self.event_loop = asyncio.get_running_loop()
self.cond = asyncio.Condition()
# self.state_ready = asyncio.Event()
self.flag = 0

gen_cfg = self._get_generation_config(gen_config)
@@ -573,14 +584,30 @@

output_ids_buf = outputs['output_ids']

seq_start = step + input_length[0]
# seq_start = step + input_length[0]

out_logprobs = None
finish = False

# async with self.tm_model.pending_cond:
# self.tm_model.pending_num -= 1
# if self.tm_model.pending_num == 0:
# self.tm_model.pending_cond.notify_all()

# self.tm_model.pending_num -= 1
# if self.tm_model.pending_num == 0:
# self.tm_model.pending_event.set()
output_ids = []
output_len = 0
prev_len = step + input_length[0]
try:
# generator
while True:
# async with self.tm_model.pending_cond:
# while self.tm_model.pending_num > 0:
# await self.tm_model.pending_cond.wait()
# await self.tm_model.pending_event.wait()

async with self.cond:
while not self.flag:
await self.cond.wait()
Expand All @@ -595,14 +622,20 @@ async def async_stream_infer(self,
yield self._get_error_output()
break

if seq_start == seq_len and not finish:
if seq_len == prev_len and not finish:
continue

output_ids = output_ids_buf[seq_start:seq_len]
gen_len = seq_len - seq_start
output_ids += output_ids_buf[prev_len:seq_len]
output_len += seq_len - prev_len

self.model_inst.report_tokens_per_tick(seq_len - prev_len)

status = ResponseType.FINISH if finish else ResponseType.SUCCESS
output = EngineOutput(status, output_ids.tolist(), gen_len,
output = EngineOutput(status, output_ids, output_len.item(),
out_logprobs)

prev_len = seq_len

yield output

if finish:
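For reference, a minimal sketch (not part of this commit) of the incremental accumulation pattern used in the revised `async_stream_infer` loop: newly generated ids are appended from the previously seen sequence length up to the current one instead of re-slicing from a fixed `seq_start`. The buffer contents and per-tick sequence lengths below are hypothetical stand-ins for the engine's output buffer and reported lengths.

```python
# Hypothetical output buffer and the seq_len values reported on successive ticks.
output_ids_buf = [101, 102, 103, 104, 105, 106]
steps = [3, 3, 5, 6]          # seq_len reported each tick
prev_len = 2                  # step + input_length: where generation starts

output_ids, output_len = [], 0
for seq_len in steps:
    if seq_len == prev_len:   # nothing new this tick
        continue
    output_ids += output_ids_buf[prev_len:seq_len]   # append only the new ids
    output_len += seq_len - prev_len
    prev_len = seq_len

print(output_ids, output_len)  # [103, 104, 105, 106] 4
```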
(The remaining 9 changed files are not shown.)
