From eeb3bd13025d3673a645a229f21e68993fe72be6 Mon Sep 17 00:00:00 2001
From: Siddharth Venkatesan
Date: Mon, 16 Dec 2024 14:30:09 -0800
Subject: [PATCH] rely on vllm argparser to validating user configs

---
 .../properties_manager/vllm_rb_properties.py | 131 +++++++-----------
 .../rolling_batch/vllm_rolling_batch.py      |   2 +-
 .../tests/test_properties_manager.py         |  21 ++-
 3 files changed, 68 insertions(+), 86 deletions(-)

diff --git a/engines/python/setup/djl_python/properties_manager/vllm_rb_properties.py b/engines/python/setup/djl_python/properties_manager/vllm_rb_properties.py
index ab8b64fd4..28c4218ca 100644
--- a/engines/python/setup/djl_python/properties_manager/vllm_rb_properties.py
+++ b/engines/python/setup/djl_python/properties_manager/vllm_rb_properties.py
@@ -10,15 +10,17 @@
 # or in the "LICENSE.txt" file accompanying this file. This file is distributed on an "AS IS"
 # BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, express or implied. See the License for
 # the specific language governing permissions and limitations under the License.
-import ast
-import dataclasses
+import argparse
+from dataclasses import asdict
 from typing import Optional, Any, Mapping, Tuple, Dict
-
 from pydantic import field_validator, model_validator, ConfigDict
 from vllm import EngineArgs
+from vllm.utils import FlexibleArgumentParser
 
 from djl_python.properties_manager.properties import Properties
 
+DEFAULT_ENGINE_ARGS = asdict(EngineArgs())
+
 DTYPE_MAPPER = {
     "fp32": "float32",
     "fp16": "float16",
@@ -38,8 +40,6 @@ class VllmRbProperties(Properties):
     # The following configs have different defaults, or additional processing in DJL compared to vLLM
     dtype: str = "auto"
     max_loras: Optional[int] = 4
-    long_lora_scaling_factors: Optional[Tuple[float, ...]] = None
-    limit_mm_per_prompt: Optional[Mapping[str, int]] = None
 
     # Neuron vLLM properties
     device: Optional[str] = None
@@ -56,44 +56,6 @@ def validate_engine(cls, engine):
                 f"Need python engine to start vLLM RollingBatcher")
         return engine
 
-    @field_validator('long_lora_scaling_factors', mode='before')
-    def validate_long_lora_scaling_factors(cls, val):
-        if isinstance(val, str):
-            val = ast.literal_eval(val)
-        if not isinstance(val, tuple):
-            if isinstance(val, list):
-                val = tuple(float(v) for v in val)
-            elif isinstance(val, float):
-                val = (val, )
-            elif isinstance(val, int):
-                val = (float(val), )
-            else:
-                raise ValueError(
-                    "long_lora_scaling_factors must be convertible to a tuple of floats."
-                )
-        return val
-
-    @field_validator('limit_mm_per_prompt', mode="before")
-    def validate_limit_mm_per_prompt(cls, val) -> Mapping[str, int]:
-        out_dict: Dict[str, int] = {}
-        for item in val.split(","):
-            kv_parts = [part.lower().strip() for part in item.split("=")]
-            if len(kv_parts) != 2:
-                raise ValueError("Each item should be in the form key=value")
-            key, value = kv_parts
-
-            try:
-                parsed_value = int(value)
-            except ValueError as e:
-                raise ValueError(
-                    f"Failed to parse value of item {key}={value}") from e
-
-            if key in out_dict and out_dict[key] != parsed_value:
-                raise ValueError(
-                    f"Conflicting values specified for key: {key}")
-            out_dict[key] = parsed_value
-        return out_dict
-
     @model_validator(mode='after')
     def validate_pipeline_parallel(self):
         if self.pipeline_parallel_degree != 1:
@@ -139,46 +101,59 @@ def djl_config_conflicts_with_vllm_config(lmi_config_name,
             raise ValueError(
                 "Both the DJL cpu_offload_gb_per_gpu and vllm cpu_offload_gb configs have been set with conflicting values."
"Only set one of these configurations") + if djl_config_conflicts_with_vllm_config("max_rolling_batch_size", + "max_num_seqs"): + raise ValueError( + "Both the DJL max_rolling_batch_size and vllm max_num_seqs configs have been set with conflicting values." + "Only set the DJL max_rolling_batch_size config") + + def generate_vllm_engine_arg_dict(self, + passthrough_vllm_engine_args) -> dict: + # We use the full set of engine args here in order for the EngineArgs.from_cli_args call to work since + # it requires all engine args at least be present. + # TODO: We may want to upstream a change to vllm here to make this a bit nicer for us + vllm_engine_args = DEFAULT_ENGINE_ARGS.copy() + # For the following configs, we only accept the LMI name currently (may be same as vllm config name) + vllm_engine_args['model'] = self.model_id_or_path + vllm_engine_args['tensor_parallel_size'] = self.tensor_parallel_degree + vllm_engine_args[ + 'pipeline_parallel_size'] = self.pipeline_parallel_degree + vllm_engine_args['max_num_seqs'] = self.max_rolling_batch_size + vllm_engine_args['dtype'] = DTYPE_MAPPER[self.dtype] + vllm_engine_args['trust_remote_code'] = self.trust_remote_code + vllm_engine_args['revision'] = self.revision + vllm_engine_args['max_loras'] = self.max_loras + vllm_engine_args['limit_mm_per_prompt'] = self.limit_mm_per_prompt + vllm_engine_args[ + 'long_lora_scaling_factors'] = self.long_lora_scaling_factors + # For these configs, either the LMI or vllm name is ok + vllm_engine_args['quantization'] = passthrough_vllm_engine_args.pop( + 'quantization', self.quantize) + vllm_engine_args[ + 'max_num_batched_tokens'] = passthrough_vllm_engine_args.pop( + 'max_num_batched_tokens', + self.max_rolling_batch_prefill_tokens) + vllm_engine_args['cpu_offload_gb'] = passthrough_vllm_engine_args.pop( + 'cpu_offload_gb', self.cpu_offload_gb_per_gpu) + # Neuron specific configs + if self.device == 'neuron': + vllm_engine_args['device'] = self.device + vllm_engine_args['preloaded_model'] = self.preloaded_model + vllm_engine_args['generation_config'] = self.generation_config + vllm_engine_args.update(passthrough_vllm_engine_args) + return vllm_engine_args def get_engine_args(self) -> EngineArgs: additional_vllm_engine_args = self.get_additional_vllm_engine_args() self.handle_lmi_vllm_config_conflicts(additional_vllm_engine_args) - max_model_len = additional_vllm_engine_args.pop("max_model_len", None) - if self.device == 'neuron': - return EngineArgs( - model=self.model_id_or_path, - preloaded_model=self.preloaded_model, - tensor_parallel_size=self.tensor_parallel_degree, - pipeline_parallel_size=self.pipeline_parallel_degree, - dtype=DTYPE_MAPPER[self.dtype], - max_num_seqs=self.max_rolling_batch_size, - block_size=max_model_len, - max_model_len=max_model_len, - trust_remote_code=self.trust_remote_code, - revision=self.revision, - device=self.device, - generation_config=self.generation_config, - **additional_vllm_engine_args, - ) - return EngineArgs( - model=self.model_id_or_path, - tensor_parallel_size=self.tensor_parallel_degree, - pipeline_parallel_size=self.pipeline_parallel_degree, - dtype=DTYPE_MAPPER[self.dtype], - max_model_len=max_model_len, - quantization=self.quantize, - max_num_batched_tokens=self.max_rolling_batch_prefill_tokens, - max_loras=self.max_loras, - long_lora_scaling_factors=self.long_lora_scaling_factors, - cpu_offload_gb=self.cpu_offload_gb_per_gpu, - limit_mm_per_prompt=self.limit_mm_per_prompt, - **additional_vllm_engine_args, - ) + vllm_engine_arg_dict = 
+            additional_vllm_engine_args)
+        namespace = argparse.Namespace(**vllm_engine_arg_dict)
+        return EngineArgs.from_cli_args(namespace)
 
     def get_additional_vllm_engine_args(self) -> Dict[str, Any]:
-        all_engine_args = EngineArgs.__annotations__
         return {
-            arg: val
-            for arg, val in self.__pydantic_extra__.items()
-            if arg in all_engine_args
+            k: v
+            for k, v in self.__pydantic_extra__.items()
+            if k in DEFAULT_ENGINE_ARGS
         }
diff --git a/engines/python/setup/djl_python/rolling_batch/vllm_rolling_batch.py b/engines/python/setup/djl_python/rolling_batch/vllm_rolling_batch.py
index 25ee2da4b..d49b753bd 100644
--- a/engines/python/setup/djl_python/rolling_batch/vllm_rolling_batch.py
+++ b/engines/python/setup/djl_python/rolling_batch/vllm_rolling_batch.py
@@ -52,7 +52,7 @@ def __init__(self, model_id_or_path: str, properties: dict,
         self.request_cache = OrderedDict()
         self.lora_id_counter = AtomicCounter(0)
         self.lora_requests = {}
-        self.is_mistral_tokenizer = self.vllm_configs.tokenizer_mode == 'mistral'
+        self.is_mistral_tokenizer = args.tokenizer_mode == 'mistral'
 
     def get_tokenizer(self):
         return self.engine.tokenizer.tokenizer
diff --git a/engines/python/setup/djl_python/tests/test_properties_manager.py b/engines/python/setup/djl_python/tests/test_properties_manager.py
index 262b0c676..ed77bf591 100644
--- a/engines/python/setup/djl_python/tests/test_properties_manager.py
+++ b/engines/python/setup/djl_python/tests/test_properties_manager.py
@@ -442,32 +442,39 @@ def test_vllm_valid(properties):
         def test_long_lora_scaling_factors(properties):
             properties['long_lora_scaling_factors'] = "3.0"
             vllm_props = VllmRbProperties(**properties)
-            self.assertEqual(vllm_props.long_lora_scaling_factors, (3.0, ))
+            engine_args = vllm_props.get_engine_args()
+            self.assertEqual(engine_args.long_lora_scaling_factors, (3.0, ))
 
             properties['long_lora_scaling_factors'] = "3"
             vllm_props = VllmRbProperties(**properties)
-            self.assertEqual(vllm_props.long_lora_scaling_factors, (3.0, ))
+            engine_args = vllm_props.get_engine_args()
+            self.assertEqual(engine_args.long_lora_scaling_factors, (3.0, ))
 
             properties['long_lora_scaling_factors'] = "3.0,4.0"
             vllm_props = VllmRbProperties(**properties)
-            self.assertEqual(vllm_props.long_lora_scaling_factors, (3.0, 4.0))
+            engine_args = vllm_props.get_engine_args()
+            self.assertEqual(engine_args.long_lora_scaling_factors, (3.0, 4.0))
 
             properties['long_lora_scaling_factors'] = "3.0, 4.0 "
             vllm_props = VllmRbProperties(**properties)
-            self.assertEqual(vllm_props.long_lora_scaling_factors, (3.0, 4.0))
+            engine_args = vllm_props.get_engine_args()
+            self.assertEqual(engine_args.long_lora_scaling_factors, (3.0, 4.0))
 
             properties['long_lora_scaling_factors'] = "(3.0,)"
             vllm_props = VllmRbProperties(**properties)
-            self.assertEqual(vllm_props.long_lora_scaling_factors, (3.0, ))
+            engine_args = vllm_props.get_engine_args()
+            self.assertEqual(engine_args.long_lora_scaling_factors, (3.0, ))
 
             properties['long_lora_scaling_factors'] = "(3.0,4.0)"
             vllm_props = VllmRbProperties(**properties)
-            self.assertEqual(vllm_props.long_lora_scaling_factors, (3.0, 4.0))
+            engine_args = vllm_props.get_engine_args()
+            self.assertEqual(engine_args.long_lora_scaling_factors, (3.0, 4.0))
 
         def test_invalid_long_lora_scaling_factors(properties):
             properties['long_lora_scaling_factors'] = "a,b"
+            vllm_props = VllmRbProperties(**properties)
             with self.assertRaises(ValueError):
-                VllmRbProperties(**properties)
+                vllm_props.get_engine_args()
 
             properties = {
                 'model_id': 'sample_model_id',
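
For reference, a minimal standalone sketch of the config flow this patch introduces: build a complete engine-arg dict from vLLM's own defaults, overlay the LMI-managed values, and let vLLM's EngineArgs.from_cli_args perform the validation. The model id and overridden values below are placeholders, and the exact set of EngineArgs fields depends on the installed vLLM version; invalid names or values now surface as errors from vLLM itself rather than from DJL-side pydantic validators.

    import argparse
    from dataclasses import asdict

    from vllm import EngineArgs

    # Start from vLLM's defaults so every expected engine arg is present,
    # mirroring DEFAULT_ENGINE_ARGS / generate_vllm_engine_arg_dict above.
    engine_arg_dict = asdict(EngineArgs())

    # Overlay LMI-managed values plus any passthrough vLLM options
    # (placeholder values, not a real configuration).
    engine_arg_dict.update({
        "model": "my-model-id",      # placeholder for option.model_id
        "tensor_parallel_size": 1,   # from option.tensor_parallel_degree
        "max_num_seqs": 32,          # from option.max_rolling_batch_size
    })

    # vLLM constructs and validates its own EngineArgs from the namespace.
    engine_args = EngineArgs.from_cli_args(argparse.Namespace(**engine_arg_dict))
    print(engine_args.model, engine_args.max_num_seqs)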