Skip to content

Commit

Permalink
rely on vllm argparser to validating user configs
Browse files Browse the repository at this point in the history
  • Loading branch information
siddvenk committed Dec 17, 2024
1 parent 916607b commit 8e7a3ad
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 77 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -10,15 +10,17 @@
# or in the "LICENSE.txt" file accompanying this file. This file is distributed on an "AS IS"
# BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, express or implied. See the License for
# the specific language governing permissions and limitations under the License.
import ast
import dataclasses
import argparse
from dataclasses import asdict
from typing import Optional, Any, Mapping, Tuple, Dict

from pydantic import field_validator, model_validator, ConfigDict
from vllm import EngineArgs
from vllm.utils import FlexibleArgumentParser

from djl_python.properties_manager.properties import Properties

DEFAULT_ENGINE_ARGS = asdict(EngineArgs())

DTYPE_MAPPER = {
"fp32": "float32",
"fp16": "float16",
Expand Down Expand Up @@ -56,44 +58,6 @@ def validate_engine(cls, engine):
f"Need python engine to start vLLM RollingBatcher")
return engine

@field_validator('long_lora_scaling_factors', mode='before')
def validate_long_lora_scaling_factors(cls, val):
if isinstance(val, str):
val = ast.literal_eval(val)
if not isinstance(val, tuple):
if isinstance(val, list):
val = tuple(float(v) for v in val)
elif isinstance(val, float):
val = (val, )
elif isinstance(val, int):
val = (float(val), )
else:
raise ValueError(
"long_lora_scaling_factors must be convertible to a tuple of floats."
)
return val

@field_validator('limit_mm_per_prompt', mode="before")
def validate_limit_mm_per_prompt(cls, val) -> Mapping[str, int]:
out_dict: Dict[str, int] = {}
for item in val.split(","):
kv_parts = [part.lower().strip() for part in item.split("=")]
if len(kv_parts) != 2:
raise ValueError("Each item should be in the form key=value")
key, value = kv_parts

try:
parsed_value = int(value)
except ValueError as e:
raise ValueError(
f"Failed to parse value of item {key}={value}") from e

if key in out_dict and out_dict[key] != parsed_value:
raise ValueError(
f"Conflicting values specified for key: {key}")
out_dict[key] = parsed_value
return out_dict

@model_validator(mode='after')
def validate_pipeline_parallel(self):
if self.pipeline_parallel_degree != 1:
Expand Down Expand Up @@ -139,46 +103,59 @@ def djl_config_conflicts_with_vllm_config(lmi_config_name,
raise ValueError(
"Both the DJL cpu_offload_gb_per_gpu and vllm cpu_offload_gb configs have been set with conflicting values."
"Only set one of these configurations")
if djl_config_conflicts_with_vllm_config("max_rolling_batch_size",
"max_num_seqs"):
raise ValueError(
"Both the DJL max_rolling_batch_size and vllm max_num_seqs configs have been set with conflicting values."
"Only set the DJL max_rolling_batch_size config")

def generate_vllm_engine_arg_dict(self,
passthrough_vllm_engine_args) -> dict:
# We use the full set of engine args here in order for the EngineArgs.from_cli_args call to work since
# it requires all engine args at least be present.
# TODO: We may want to upstream a change to vllm here to make this a bit nicer for us
vllm_engine_args = DEFAULT_ENGINE_ARGS.copy()
# For the following configs, we only accept the LMI name currently (may be same as vllm config name)
vllm_engine_args['model'] = self.model_id_or_path
vllm_engine_args['tensor_parallel_size'] = self.tensor_parallel_degree
vllm_engine_args[
'pipeline_parallel_size'] = self.pipeline_parallel_degree
vllm_engine_args['max_num_seqs'] = self.max_rolling_batch_size
vllm_engine_args['dtype'] = DTYPE_MAPPER[self.dtype]
vllm_engine_args['trust_remote_code'] = self.trust_remote_code
vllm_engine_args['revision'] = self.revision
vllm_engine_args['max_loras'] = self.max_loras
vllm_engine_args['limit_mm_per_prompt'] = self.limit_mm_per_prompt
vllm_engine_args[
'long_lora_scaling_factors'] = self.long_lora_scaling_factors
# For these configs, either the LMI or vllm name is ok
vllm_engine_args['quantization'] = passthrough_vllm_engine_args.pop(
'quantization', self.quantize)
vllm_engine_args[
'max_num_batched_tokens'] = passthrough_vllm_engine_args.pop(
'max_num_batched_tokens',
self.max_rolling_batch_prefill_tokens)
vllm_engine_args['cpu_offload_gb'] = passthrough_vllm_engine_args.pop(
'cpu_offload_gb', self.cpu_offload_gb_per_gpu)
# Neuron specific configs
if self.device == 'neuron':
vllm_engine_args['device'] = self.device
vllm_engine_args['preloaded_model'] = self.preloaded_model
vllm_engine_args['generation_config'] = self.generation_config
vllm_engine_args.update(passthrough_vllm_engine_args)
return vllm_engine_args

def get_engine_args(self) -> EngineArgs:
additional_vllm_engine_args = self.get_additional_vllm_engine_args()
self.handle_lmi_vllm_config_conflicts(additional_vllm_engine_args)
max_model_len = additional_vllm_engine_args.pop("max_model_len", None)
if self.device == 'neuron':
return EngineArgs(
model=self.model_id_or_path,
preloaded_model=self.preloaded_model,
tensor_parallel_size=self.tensor_parallel_degree,
pipeline_parallel_size=self.pipeline_parallel_degree,
dtype=DTYPE_MAPPER[self.dtype],
max_num_seqs=self.max_rolling_batch_size,
block_size=max_model_len,
max_model_len=max_model_len,
trust_remote_code=self.trust_remote_code,
revision=self.revision,
device=self.device,
generation_config=self.generation_config,
**additional_vllm_engine_args,
)
return EngineArgs(
model=self.model_id_or_path,
tensor_parallel_size=self.tensor_parallel_degree,
pipeline_parallel_size=self.pipeline_parallel_degree,
dtype=DTYPE_MAPPER[self.dtype],
max_model_len=max_model_len,
quantization=self.quantize,
max_num_batched_tokens=self.max_rolling_batch_prefill_tokens,
max_loras=self.max_loras,
long_lora_scaling_factors=self.long_lora_scaling_factors,
cpu_offload_gb=self.cpu_offload_gb_per_gpu,
limit_mm_per_prompt=self.limit_mm_per_prompt,
**additional_vllm_engine_args,
)
vllm_engine_arg_dict = self.generate_vllm_engine_arg_dict(
additional_vllm_engine_args)
namespace = argparse.Namespace(**vllm_engine_arg_dict)
return EngineArgs.from_cli_args(namespace)

def get_additional_vllm_engine_args(self) -> Dict[str, Any]:
all_engine_args = EngineArgs.__annotations__
return {
arg: val
for arg, val in self.__pydantic_extra__.items()
if arg in all_engine_args
k: v
for k, v in self.__pydantic_extra__.items()
if k in DEFAULT_ENGINE_ARGS
}
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def __init__(self, model_id_or_path: str, properties: dict,
self.request_cache = OrderedDict()
self.lora_id_counter = AtomicCounter(0)
self.lora_requests = {}
self.is_mistral_tokenizer = self.vllm_configs.tokenizer_mode == 'mistral'
self.is_mistral_tokenizer = args.tokenizer_mode == 'mistral'

def get_tokenizer(self):
return self.engine.tokenizer.tokenizer
Expand Down

0 comments on commit 8e7a3ad

Please sign in to comment.