release pytorch engine

irexyc committed Oct 10, 2024
1 parent 1585809 commit 4d35004
Showing 4 changed files with 79 additions and 6 deletions.
17 changes: 17 additions & 0 deletions lmdeploy/pytorch/engine/engine.py
@@ -166,6 +166,7 @@ def __init__(self,
         self.req_manager = self._bind_request_manager()

         # create main thread
+        self._stop_flag = False
         self._start_loop()
         self._create_buffers()
         self.engine_instance = self.create_instance()
@@ -241,6 +242,18 @@ def _bind_request_manager(self):
         req_manager.bind_func(RequestType.ADD_MESSAGE, self._on_add_message)
         return req_manager

+    def close(self):
+        self._stop_flag = True
+        self.req_manager.close()
+        self.model_agent.close()
+        self.model_agent = None
+        self._seq_length_buf = None
+        self._inputs = None
+        torch._C._cuda_clearCublasWorkspaces()
+        torch.cuda.empty_cache()
+        import gc
+        gc.collect()
+
     def _start_loop(self):
         """start loop."""
         return self.req_manager.start_loop(self.async_loop)
@@ -930,6 +943,10 @@ async def __step():
             out_que.task_done()

         while True:
+            if self._stop_flag:
+                logger.info('Stop _async_loop')
+                loop_background.cancel()
+                break
             if self.req_manager.has_requests():
                 self.req_manager.step()

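Note: the engine change above is a cooperative-shutdown pattern. `close()` only flips `_stop_flag` (and then drops CUDA caches via `torch.cuda.empty_cache()` plus the private `torch._C._cuda_clearCublasWorkspaces()`); the main loop observes the flag on its next iteration, cancels the background task, and exits. A minimal self-contained sketch of the same pattern (the class and names here are illustrative, not lmdeploy's API):

import asyncio


class LoopWorker:
    """Cooperative shutdown via a stop flag, mirroring Engine.close()."""

    def __init__(self):
        self._stop_flag = False

    def close(self):
        # Only request shutdown; the loop notices the flag on its next pass.
        self._stop_flag = True

    async def run(self):
        background = asyncio.create_task(asyncio.sleep(3600))
        while True:
            if self._stop_flag:
                background.cancel()  # like loop_background.cancel() above
                break
            await asyncio.sleep(0.01)  # stand-in for one scheduling step


async def main():
    worker = LoopWorker()
    task = asyncio.create_task(worker.run())
    await asyncio.sleep(0.05)
    worker.close()
    await task  # returns promptly once the flag is observed


asyncio.run(main())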
45 changes: 39 additions & 6 deletions lmdeploy/pytorch/engine/model_agent.py
@@ -2,8 +2,11 @@
 import asyncio
 import atexit
 import os
+import sys
 from datetime import timedelta
+from functools import partial
 from typing import Any, Callable, Dict, List
+from weakref import ReferenceType, ref

 import torch
 import torch.distributed as dist
@@ -192,6 +195,10 @@ def get_logits(self, hidden_states: torch.Tensor):
         """get logits of model output."""
         raise NotImplementedError('Not implemented.')

+    def close(self):
+        """release model."""
+        pass
+

 class BaseModelAgent(AutoModelAgent):
     """Base model agent.
@@ -235,6 +242,9 @@ def __init__(self,

         self.stream = torch.cuda.Stream()

+    def close(self):
+        del self.patched_model
+
     def _build_model(self,
                      model_path: str,
                      adapters: Dict[str, str] = None,
@@ -540,10 +550,11 @@ def __init__(self,
                  trust_remote_code: bool = True) -> None:
         import signal

-        def __signal_term_handler(sig, frame):
+        def __signal_term_handler(sig, frame, agent):
             """sigterm handler."""
-            if hasattr(self, 'mp_context'):
-                procs = self.mp_context.processes
+            agent = agent()
+            if hasattr(agent, 'mp_context'):
+                procs = agent.mp_context.processes
                 for p in procs:
                     if p.is_alive():
                         p.kill()
@@ -553,8 +564,10 @@ def __signal_term_handler(sig, frame):

         super().__init__(model_config=model_config, cache_config=cache_config)

-        signal.signal(signal.SIGTERM, __signal_term_handler)
+        signal.signal(signal.SIGTERM,
+                      partial(__signal_term_handler, agent=ref(self)))

+        self.old_sys_excepthook = sys.excepthook
         self.mp_ctx = mp.get_context('spawn')
         self.world_size = world_size
         self.backend_config = backend_config
@@ -579,6 +592,22 @@ def __signal_term_handler(sig, frame):
         self.cache_config = cache_config
         self.cache_engine = cache_engine
         self.stream = torch.cuda.Stream()
+        self.stop = False
+
+    def close(self):
+        _exit_by_sending_exit_flag(0, ref(self))
+        self.stop = True
+        procs: List[mp.Process] = self.mp_context.processes
+        for p in procs:
+            if p.is_alive():
+                logger.info(f'Terminate {p}')
+                p.terminate()
+            else:
+                logger.info(f'Close {p}')
+                p.close()
+        if dist.is_initialized():
+            dist.destroy_process_group()
+        sys.excepthook = self.old_sys_excepthook

     def _start_sub_process(self, model_path: str, model_config: ModelConfig,
                            cache_config: CacheConfig,
@@ -627,7 +656,7 @@ def _start_sub_process(self, model_path: str, model_config: ModelConfig,
             dist.destroy_process_group()
             raise e
         # Please see Note [Exit By Sending Exit Flag]
-        atexit.register(_exit_by_sending_exit_flag, rank, self)
+        atexit.register(_exit_by_sending_exit_flag, rank, ref(self))

     @torch.inference_mode()
     def _build_model(
@@ -715,10 +744,14 @@ def get_logits(self, hidden_states: torch.Tensor):
         return self.patched_model.get_logits(hidden_states)


-def _exit_by_sending_exit_flag(rank: int, agent: TPModelAgent):
+def _exit_by_sending_exit_flag(rank: int,
+                               agent: 'ReferenceType[TPModelAgent]'):
     """[Note] Exit By Sending Exit Flag: the registration to `atexit` of this
     function should be called after importing torch.multiprocessing and the
     initialization of distributed process group."""
+    agent = agent()
+    if agent is None or getattr(agent, 'stop', False):
+        return
     if not hasattr(agent, 'stream'):
         # agent is not initialized, just exits normally
         if hasattr(agent, 'patched_model'):
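Note: both the SIGTERM handler and the `atexit` hook previously captured `self`, so the interpreter's global handler tables kept `TPModelAgent` alive for the whole process lifetime. The commit switches both to `weakref.ref(self)` and dereferences at call time; `close()` also runs `_exit_by_sending_exit_flag` eagerly and sets `stop = True` so the later `atexit` invocation becomes a no-op. A minimal sketch of the weakref-handler pattern, with illustrative names:

import signal
from functools import partial
from weakref import ref


class Agent:

    def __init__(self):
        # Register a weak reference: signal.signal stores its handler
        # globally, and a strong `self` would never be collected.
        signal.signal(signal.SIGTERM,
                      partial(self._on_sigterm, agent=ref(self)))

    @staticmethod
    def _on_sigterm(sig, frame, agent):
        agent = agent()  # upgrade the weakref at call time
        if agent is None:  # already garbage-collected: nothing to do
            return
        agent.shutdown()

    def shutdown(self):
        print('terminating worker processes ...')


agent = Agent()
signal.raise_signal(signal.SIGTERM)  # handler runs, agent is still alive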
17 changes: 17 additions & 0 deletions lmdeploy/pytorch/engine/request.py
@@ -377,6 +377,23 @@ def __init__(self, thread_safe: bool = False):
         if thread_safe:
             self.thread_requests = Queue()

+    def close(self):
+        if not self._thread_safe:
+            if self._loop_task is not None:
+                _run_until_complete(self._loop_task)
+        else:
+            loop = self.event_loop
+            tasks = asyncio.all_tasks(loop=loop)
+
+            async def cancel_tasks():
+                for task in tasks:
+                    task.cancel()
+
+            f = asyncio.run_coroutine_threadsafe(cancel_tasks(), loop=loop)
+            f.result()
+            loop.call_soon_threadsafe(loop.stop)
+            self._loop_thread.join()
+
     def create_loop_task(self):
         """create coro task."""
         logger.debug('creating engine loop task.')
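Note: `RequestManager.close()` handles two ownership models. When the loop runs on the caller's event loop, it is simply driven to completion (the stop flag set by `Engine.close()` makes it return); in thread-safe mode the loop lives on its own thread, so tasks must be cancelled from inside that loop and the loop stopped thread-safely. A standalone sketch of the thread-safe path (illustrative names, not lmdeploy's API):

import asyncio
import threading
import time

loop = asyncio.new_event_loop()
thread = threading.Thread(target=loop.run_forever, daemon=True)
thread.start()


async def forever():
    await asyncio.sleep(3600)


asyncio.run_coroutine_threadsafe(forever(), loop=loop)
time.sleep(0.1)  # let the task get scheduled


def close():
    tasks = asyncio.all_tasks(loop=loop)

    async def cancel_tasks():
        for task in tasks:
            task.cancel()

    # Cancellation must run on the loop's own thread.
    asyncio.run_coroutine_threadsafe(cancel_tasks(), loop=loop).result()
    loop.call_soon_threadsafe(loop.stop)  # thread-safe shutdown request
    thread.join()


close()
print('loop thread joined, all tasks cancelled')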
6 changes: 6 additions & 0 deletions lmdeploy/serve/async_engine.py
@@ -186,6 +186,12 @@ def __init__(self,
         self._session_id = count(0)
         self.request_logger = RequestLogger(max_log_len)

+    def close(self):
+        if hasattr(self, 'engine'):
+            if isinstance(self.backend_config, PytorchEngineConfig):
+                self.engine.close()
+            del self.engine
+
     def _build_turbomind(
         self,
         model_path: str,
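Note: the guard in `close()` means only the PyTorch engine gets an explicit `close()`; the TurboMind backend is simply dropped with `del self.engine`. With these pieces in place, a PyTorch-backend pipeline can be shut down explicitly instead of leaking its engine thread and GPU memory until interpreter exit. A usage sketch, assuming the usual `pipeline()` entry point (the model path is a placeholder):

from lmdeploy import PytorchEngineConfig, pipeline

pipe = pipeline('internlm/internlm2-chat-7b',  # placeholder model path
                backend_config=PytorchEngineConfig())
response = pipe(['Hi, please introduce yourself'])
print(response)

pipe.close()  # stops the engine loop and releases GPU memory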
