wait-free

lzhangzz committed Dec 12, 2024
1 parent 2cf49bd commit aa5573d
Showing 8 changed files with 82 additions and 100 deletions.
58 changes: 29 additions & 29 deletions benchmark/profile_throughput.py
@@ -89,7 +89,7 @@ def __init__(self, model_path: str,

async def _inference(self, req_queue: Queue, session_id: int,
temperature: float, top_p: float, top_k: int,
stream_output: bool, pretokenize: bool,
stream_output: bool, skip_tokenize: bool,
skip_detokenize: bool):
model_inst = self.tm_model.create_instance()
counters = []
@@ -99,13 +99,16 @@ async def _inference(self, req_queue: Queue, session_id: int,
ts = [time.perf_counter()]
ns = [0]

if pretokenize:
if skip_tokenize:
input_ids = prompt
else:
input_ids = self.tokenizer(prompt).input_ids

state = DetokenizeState(len(input_ids))

prev_len = 0
token_ids = input_ids.copy()

async for outputs in model_inst.async_stream_infer(
session_id,
input_ids=input_ids,
@@ -117,16 +120,15 @@ async def _inference(self, req_queue: Queue, session_id: int,
sequence_start=True,
sequence_end=True,
stream_output=stream_output):
res, n_token = input_ids + outputs.token_ids, outputs.num_token
if not skip_detokenize:
_, state = self.tokenizer.detokenize_incrementally(
res, state)
# The following does not help
# await asyncio.sleep(0)
# _, state = await loop.run_in_executor(None, self.tokenizer.detokenize_incrementally, res, state)

ts.append(time.perf_counter())
ns.append(n_token)
n_token = outputs.num_token
if n_token > prev_len:
token_ids += outputs.token_ids[prev_len - n_token:]
if not skip_detokenize:
_, state = self.tokenizer.detokenize_incrementally(
token_ids, state)
ts.append(time.perf_counter())
ns.append(n_token)
prev_len = n_token

# for pytorch engine to restart a session
if isinstance(model_inst, EngineInstance):
@@ -138,41 +140,42 @@ async def _inference(self, req_queue: Queue, session_id: int,
return counters

def process_request(self, requests, concurrency, temperature, top_p, top_k,
stream_output, pretokenize, skip_detokenize):
stream_output, skip_tokenize, skip_detokenize):
req_queue = Queue()

self.pbar = tqdm(total=len(requests))

# feed request to q
for req in requests:
if pretokenize:
if skip_tokenize:
req_queue.put((self.tokenizer.encode(req[0]), *req[1:]))
else:
req_queue.put(req)
for i in range(concurrency):
req_queue.put([None, None, None])

start = time.time()

event_loop = asyncio.new_event_loop()
asyncio.set_event_loop(event_loop)

# start threads
tasks = []
for i in range(concurrency):
task = self._inference(req_queue, i, temperature, top_p, top_k,
stream_output, pretokenize, skip_detokenize)
stream_output, skip_tokenize,
skip_detokenize)
tasks.append(task)

async def _gather_tasks(tasks):
return await asyncio.gather(*tasks)

counters = asyncio.run(_gather_tasks(tasks))
self.pbar = tqdm(total=len(requests))

self.pbar.close()
start = time.time()

counters = asyncio.run(_gather_tasks(tasks))

elapsed_time = time.time() - start

self.pbar.close()

ttfts: List[float] = []
tpots: List[float] = []
e2es: List[float] = []
@@ -183,9 +186,6 @@ async def _gather_tasks(tasks):
total_input = 0

for ts, ns, input_len in itertools.chain.from_iterable(counters):
# print (ts)
# print (ns)
# assert 0
total_output += ns[-1]
total_input += input_len
e2es.append(ts[-1] - ts[0])
Expand All @@ -197,7 +197,7 @@ async def _gather_tasks(tasks):
t_dif = np.subtract(ts[1:], ts[:-1])
n_dif = np.subtract(ns[1:], ns[:-1])
itls.extend(t_dif[1:])
tpts.extend(n_dif[1:])
tpts.extend(n_dif)

output_throughput = total_output / elapsed_time
input_throughput = total_input / elapsed_time
@@ -232,8 +232,8 @@ def fmt(x):
tab_row('Total requests', len(requests))
tab_row('Concurrency', concurrency)
tab_row('Stream output', str(stream_output).lower())
tab_row('Pre-tokenization', str(pretokenize).lower())
tab_row('Skip detokenization', str(skip_detokenize).lower())
tab_row('Skip tokenize', str(skip_tokenize).lower())
tab_row('Skip detokenize', str(skip_detokenize).lower())
tab_row('Total input tokens', total_input)
tab_row('Total generated tokens', total_output)
tab_row('Input token throughput (tok/s)', input_throughput)
@@ -275,7 +275,7 @@ def parse_args():
parser.add_argument('--no-stream-output',
action='store_true',
help='Use stream output')
parser.add_argument('--pre-tokenize',
parser.add_argument('--skip-tokenize',
action='store_true',
help='Pre-tokenize input prompts before starting')
parser.add_argument('--skip-detokenize',
@@ -366,7 +366,7 @@ def main():
top_k=args.top_k,
concurrency=args.concurrency,
stream_output=not args.no_stream_output,
pretokenize=args.pre_tokenize,
skip_tokenize=args.skip_tokenize,
skip_detokenize=args.skip_detokenize)


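Note on the streaming loop above: instead of rebuilding `input_ids + outputs.token_ids` on every event, the benchmark now keeps a running `token_ids` list plus a `prev_len` counter, appends only the newly reported suffix, and feeds the accumulated ids to `detokenize_incrementally`. The sketch below isolates that accumulation step so it can run standalone; the `Output` dataclass and `fake_stream` generator are hypothetical stand-ins for `async_stream_infer`, assumed here to report all tokens generated so far in `token_ids` and their count in `num_token`.

```python
import asyncio
from dataclasses import dataclass
from typing import List


@dataclass
class Output:
    token_ids: List[int]  # all generated ids reported so far (assumption)
    num_token: int        # number of generated tokens


async def fake_stream(n: int = 6):
    """Hypothetical stand-in for model_inst.async_stream_infer(...)."""
    generated: List[int] = []
    for i in range(n):
        generated.append(100 + i)
        yield Output(token_ids=list(generated), num_token=len(generated))


async def run(input_ids: List[int]) -> List[int]:
    token_ids = input_ids.copy()  # prompt + everything decoded so far
    prev_len = 0
    async for out in fake_stream():
        n_token = out.num_token
        if n_token > prev_len:
            # prev_len - n_token is negative, so this slice keeps exactly the
            # last (n_token - prev_len) ids, i.e. the newly produced suffix.
            token_ids += out.token_ids[prev_len - n_token:]
            prev_len = n_token
            # The real loop would call
            # tokenizer.detokenize_incrementally(token_ids, state) here.
    return token_ids


print(asyncio.run(run([1, 2, 3])))  # [1, 2, 3, 100, 101, 102, 103, 104, 105]
```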
13 changes: 9 additions & 4 deletions lmdeploy/turbomind/chat.py
@@ -48,6 +48,8 @@ def infer(generator, session_id, input_ids, gen_config, sequence_start,
async def async_infer(generator, session_id, input_ids, gen_config,
sequence_start, sequence_end, step, stream_output,
tokenizer, state):
token_ids = input_ids.copy()
prev_len = 0
async for output in generator.async_stream_infer(
session_id=session_id,
input_ids=input_ids,
@@ -56,10 +58,13 @@ async def async_infer(generator, session_id, input_ids, gen_config,
sequence_end=sequence_end,
step=step,
stream_output=stream_output):
res, tokens = input_ids + output.token_ids, output.num_token
# decode res
response, state = tokenizer.detokenize_incrementally(res, state=state)
print(response, end='', flush=True)
tokens = output.num_token
if tokens > prev_len:
token_ids += output.token_ids[prev_len - tokens:]
response, state = tokenizer.detokenize_incrementally(token_ids,
state=state)
prev_len = tokens
print(response, end='', flush=True)
return tokens


42 changes: 10 additions & 32 deletions lmdeploy/turbomind/turbomind.py
@@ -132,9 +132,6 @@ def __init__(self,
self.session_len = self.config.session_len
self.eos_id = self.tokenizer.eos_token_id

self.pending_num = 0
self.pending_cond = asyncio.Condition()

def _create_weight(self, model_comm):
"""Allocate weight buffer, load params if from_workspace."""

@@ -510,22 +507,15 @@ def prepare_inputs(self,

return inputs, input_lengths

async def async_signal(self, state):
async def async_signal(self):
async with self.cond:
self.flag, self.state = 1, state
self.flag = 1
self.cond.notify()
# self.flag, self.state = 1, state
# self.state_ready.set()

def async_signal_cb(self, state):
coro = self.async_signal(state)
def async_signal_cb(self):
coro = self.async_signal()
asyncio.run_coroutine_threadsafe(coro, self.event_loop)

def add_pending(self, n: int = 1):
# self.tm_model.pending_event.clear()
# self.tm_model.pending_num += n
self.tm_model.pending_num += n

async def async_stream_infer(self,
session_id,
input_ids,
@@ -577,41 +567,29 @@ async def async_stream_infer(self,

inputs = _np_dict_to_tm_dict(inputs)

outputs = self.model_inst.forward(inputs, session, gen_cfg,
stream_output, self.async_signal_cb)
outputs, shared_state = self.model_inst.forward(
inputs, session, gen_cfg, stream_output, self.async_signal_cb)

outputs = _tm_dict_to_torch_dict(outputs)

output_ids_buf = outputs['output_ids']

# seq_start = step + input_length[0]

out_logprobs = None
finish = False

# async with self.tm_model.pending_cond:
# self.tm_model.pending_num -= 1
# if self.tm_model.pending_num == 0:
# self.tm_model.pending_cond.notify_all()

# self.tm_model.pending_num -= 1
# if self.tm_model.pending_num == 0:
# self.tm_model.pending_event.set()
output_ids = []
output_len = 0
prev_len = step + input_length[0]
try:
# generator
while True:
# async with self.tm_model.pending_cond:
# while self.tm_model.pending_num > 0:
# await self.tm_model.pending_cond.wait()
# await self.tm_model.pending_event.wait()

async with self.cond:
while not self.flag:
await self.cond.wait()
state, self.flag = self.state, 0
self.flag = 0

state = shared_state.consume()

status, seq_len = state.status, state.seq_len

Expand All @@ -625,7 +603,7 @@ async def async_stream_infer(self,
if seq_len == prev_len and not finish:
continue

output_ids += output_ids_buf[prev_len:seq_len]
output_ids += output_ids_buf[prev_len:seq_len].tolist()
output_len += seq_len - prev_len

self.model_inst.report_tokens_per_tick(seq_len - prev_len)
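The rewrite of `async_stream_infer` above replaces the old engine-wide `pending_num`/`pending_cond` bookkeeping with a per-request handshake: the engine fires `async_signal_cb`, which only sets `self.flag` and notifies `self.cond`, and the generator then calls `shared_state.consume()` to pick up the latest `RequestState`. The snippet below is a runnable pure-Python analogue of that flow; `SharedState`, its `publish` method, and the `K_OK`/`K_FINISH` constants are stand-ins for the pybind11-exposed `AtomicRequestState` and the `Request` status enum, so read it as an illustration of the intended semantics rather than the actual API.

```python
import asyncio
from dataclasses import dataclass
from typing import Optional

K_OK, K_FINISH = 0, 7  # mirrors Request::kOk / Request::kFinish (assumption)


@dataclass
class RequestState:
    status: int
    seq_len: int


class SharedState:
    """Single-slot mailbox standing in for AtomicRequestState: publishing
    overwrites the slot, consuming takes the latest value and empties it."""

    def __init__(self) -> None:
        self._slot: Optional[RequestState] = None

    def publish(self, state: RequestState) -> Optional[RequestState]:
        prev, self._slot = self._slot, state
        return prev

    def consume(self) -> Optional[RequestState]:
        state, self._slot = self._slot, None
        return state


async def main() -> None:
    cond = asyncio.Condition()
    shared = SharedState()
    flag = 0

    async def signal() -> None:          # ~ async_signal()
        nonlocal flag
        async with cond:
            flag = 1
            cond.notify()

    async def producer() -> None:        # the engine emitting states
        for seq_len in range(1, 6):
            status = K_FINISH if seq_len == 5 else K_OK
            shared.publish(RequestState(status, seq_len))
            await signal()
            if seq_len % 2 == 0:         # yield occasionally so that some
                await asyncio.sleep(0)   # intermediate states get coalesced

    async def consumer() -> None:        # ~ the while loop in async_stream_infer
        nonlocal flag
        while True:
            async with cond:
                while not flag:
                    await cond.wait()
                flag = 0
            state = shared.consume()     # only the newest state is seen
            print(f'seq_len={state.seq_len} status={state.status}')
            if state.status == K_FINISH:
                break

    await asyncio.gather(producer(), consumer())


asyncio.run(main())
```

Because the generator slices `output_ids_buf[prev_len:seq_len]` from the shared output buffer, dropping the intermediate states that get overwritten in the slot loses no tokens, only the stale progress reports.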
19 changes: 12 additions & 7 deletions src/turbomind/models/llama/LlamaBatch.cc
@@ -1406,13 +1406,15 @@ auto LlamaBatch<T>::Finish(GenerationState& g) -> std::vector<Signal>
}
else if (r->stream_output && rank_ == 0) {
const auto seq_len = r->outputs.getVal<int>("sequence_length");
const auto v = r->flag->load(std::memory_order_relaxed) < 1;
if (v) {
r->flag->fetch_add(1, std::memory_order_relaxed);
if (true) {
// Create signals by copying the request handles for non-finished streaming requests
signals.push_back([this, r, seq_len] {
try {
r->forward_cb({Request::kOk, seq_len});
auto new_state = new RequestState{Request::kOk, seq_len};
auto old_state = r->state->exchange(new_state);
if (!old_state) {
r->forward_cb();
}
}
catch (const std::bad_function_call& e) {
TM_LOG_ERROR("Null stream callback for (%s)", std::to_string(r->id).c_str());
@@ -1494,11 +1496,14 @@ auto LlamaBatch<T>::Interrupt(int index, bool force_stop, bool force_end) -> Sig

auto ec = std::exchange(state_->errors[index], Request::kOk);

const auto len = state_->requests[index]->outputs.getVal<int>("sequence_length");
// move the request handle into the signal
return [this, ec, r = std::move(state_->requests[index])] {
return [this, ec, len, r = std::move(state_->requests[index])] {
if (rank_ == 0) {
if (r->forward_cb) {
r->forward_cb({Request::kFinish, r->outputs.getVal<int>("sequence_length")});
auto new_state = new RequestState{Request::kFinish, len};
auto old_state = r->state->exchange(new_state);
if (!old_state) {
r->forward_cb();
}
}
};
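These two hunks are the producer side of that handshake: `Finish` and `Interrupt` no longer invoke the callback with a payload; they allocate a fresh `RequestState`, swap it into the request's atomic slot, and call `forward_cb()` only when the previous slot was empty, i.e. when the consumer has presumably already drained the last notification. A minimal Python stand-in for that notify-once rule (the `Slot` class and `publish` helper are hypothetical, not the C++ types):

```python
from typing import Callable, Optional, Tuple

StateT = Tuple[int, int]  # (status, seq_len), standing in for RequestState


class Slot:
    """Single-slot mailbox; exchange() mirrors the atomic pointer swap."""

    def __init__(self) -> None:
        self._value: Optional[StateT] = None

    def exchange(self, new: Optional[StateT]) -> Optional[StateT]:
        old, self._value = self._value, new
        return old


def publish(slot: Slot, state: StateT, notify: Callable[[], None]) -> None:
    # Notify only when the previous state had already been consumed;
    # otherwise the wake-up that is still pending will deliver this
    # newer state anyway.
    if slot.exchange(state) is None:
        notify()


slot = Slot()
publish(slot, (0, 4), lambda: print('notify'))  # empty slot -> "notify"
publish(slot, (0, 5), lambda: print('notify'))  # unconsumed -> coalesced
print(slot.exchange(None))                      # consumer takes latest: (0, 5)
```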
9 changes: 2 additions & 7 deletions src/turbomind/models/llama/Request.h
@@ -98,12 +98,9 @@ struct Request {
std::function<void(int)> cancel_cb;
std::function<void(int)> end_cb;

std::function<void(RequestState)> forward_cb;
std::function<void()> forward_cb;

// std::atomic_flag* flag;
std::atomic<int>* flag;

std::atomic<int>* seq_len;
std::shared_ptr<AtomicRequestState> state;

enum {
kOk = 0,
@@ -115,8 +112,6 @@ struct Request {
kTooLong = 6, // history + prompt > session_len,
kFinish = 7,
};

// std::promise<int> signal;
};

class RequestQueue {
24 changes: 12 additions & 12 deletions src/turbomind/python/bind.cpp
@@ -331,12 +331,12 @@ PYBIND11_MODULE(_turbomind, m)
return oss.str();
});

py::class_<ft::RequestState, std::shared_ptr<ft::RequestState>>(m, "RequestState")
py::class_<ft::RequestState, std::unique_ptr<ft::RequestState>>(m, "RequestState")
.def_readonly("status", &ft::RequestState::status)
.def_readonly("seq_len", &ft::RequestState::seq_len);

// py::class_<ft::AtomicRequestState, std::shared_ptr<ft::AtomicRequestState>>(m, "AtomicRequestState")
// .def("load", [](ft::AtomicRequestState& s) { return s.load(); });
py::class_<ft::AtomicRequestState, std::shared_ptr<ft::AtomicRequestState>>(m, "AtomicRequestState")
.def("consume", [](ft::AtomicRequestState& s) { return s.exchange(nullptr); });

// data type
py::enum_<ft::DataType>(m, "DataType")
@@ -451,26 +451,26 @@ PYBIND11_MODULE(_turbomind, m)
py::class_<ModelRequest>(m, "ModelRequest")
.def(
"forward",
[](ModelRequest* model_request,
std::shared_ptr<TensorMap> input_tensors,
const ft::SessionParam& session,
const ft::GenerationConfig& gen_cfg,
bool stream_output,
std::function<void(ft::RequestState)> cb) {
[](ModelRequest* model_request,
std::shared_ptr<TensorMap> input_tensors,
const ft::SessionParam& session,
const ft::GenerationConfig& gen_cfg,
bool stream_output,
std::function<void()> cb) {
ModelRequest::InputParam param{};
param.tensors = std::move(input_tensors);
param.session = session;
param.gen_cfg = gen_cfg;
param.stream_output = stream_output;
auto ret = model_request->Forward(std::move(param), [cb = std::move(cb)](ft::RequestState s) {
auto ret = model_request->Forward(std::move(param), [cb = std::move(cb)]() {
try {
cb(s);
cb();
}
catch (const py::error_already_set& e) {
std::cerr << e.what() << std::endl;
}
});
return ret.tensors;
return std::make_tuple(std::move(ret.tensors), std::move(ret.state));
},
py::call_guard<py::gil_scoped_release>(),
"input_tensors"_a,