Skip to content

Commit

Permalink
rely on vllm argparser to validate user configs
Browse files Browse the repository at this point in the history
  • Loading branch information
siddvenk committed Dec 20, 2024
1 parent 7ac8965 commit 1ecdbcf
Show file tree
Hide file tree
Showing 5 changed files with 211 additions and 173 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -61,9 +61,9 @@ class Properties(BaseModel):
input_formatter: Optional[Callable] = None
waiting_steps: Optional[int] = None
mpi_mode: bool = False
tgi_compat: Optional[bool] = False
bedrock_compat: Optional[bool] = False
enable_lora: Optional[bool] = False
tgi_compat: bool = False
bedrock_compat: bool = False
enable_lora: bool = False

# Spec_dec
draft_model_id: Optional[str] = None
Expand Down
200 changes: 95 additions & 105 deletions engines/python/setup/djl_python/properties_manager/vllm_rb_properties.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,12 @@
# BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, express or implied. See the License for
# the specific language governing permissions and limitations under the License.
import ast
import dataclasses
from typing import Optional, Any, Mapping, Tuple, Dict

from pydantic import field_validator, model_validator, ConfigDict
import logging
from typing import Optional, Any, Dict, Tuple
from pydantic import field_validator, model_validator, ConfigDict, Field
from vllm import EngineArgs
from vllm.utils import FlexibleArgumentParser
from vllm.engine.arg_utils import StoreBoolean

from djl_python.properties_manager.properties import Properties

Expand All @@ -29,17 +30,19 @@

class VllmRbProperties(Properties):
engine: Optional[str] = None
# The following configs have different names in DJL compared to vLLM
quantize: Optional[str] = None
# The following configs have different names in DJL compared to vLLM, we only accept DJL name currently
tensor_parallel_degree: int = 1
pipeline_parallel_degree: int = 1
max_rolling_batch_prefill_tokens: Optional[int] = None
cpu_offload_gb_per_gpu: Optional[int] = 0
# The following configs have different names in DJL compared to vLLM, either is accepted
quantize: Optional[str] = Field(alias="quantization", default=None)
max_rolling_batch_prefill_tokens: Optional[int] = Field(
alias="max_num_batched_tokens", default=None)
cpu_offload_gb_per_gpu: Optional[float] = Field(alias="cpu_offload_gb",
default=None)
# The following configs have different defaults, or additional processing in DJL compared to vLLM
dtype: str = "auto"
max_loras: Optional[int] = 4
max_loras: int = 4
long_lora_scaling_factors: Optional[Tuple[float, ...]] = None
limit_mm_per_prompt: Optional[Mapping[str, int]] = None

# Neuron vLLM properties
device: Optional[str] = None
Expand All @@ -56,7 +59,17 @@ def validate_engine(cls, engine):
f"Need python engine to start vLLM RollingBatcher")
return engine

@model_validator(mode='after')
def validate_pipeline_parallel(self):
    """Ensure pipeline parallelism is not requested.

    Returns:
        The instance itself when ``pipeline_parallel_degree`` equals 1.

    Raises:
        ValueError: if ``pipeline_parallel_degree`` is anything other than 1.
    """
    # Guard clause: the only accepted value is 1.
    if self.pipeline_parallel_degree == 1:
        return self
    raise ValueError(
        "Pipeline parallelism is not supported in vLLM's LLMEngine used in rolling_batch implementation"
    )

@field_validator('long_lora_scaling_factors', mode='before')
# TODO: processing of this field is broken in vllm via from_cli_args
# we should upstream a fix for this to vllm
def validate_long_lora_scaling_factors(cls, val):
if isinstance(val, str):
val = ast.literal_eval(val)
Expand All @@ -73,112 +86,89 @@ def validate_long_lora_scaling_factors(cls, val):
)
return val

@field_validator('limit_mm_per_prompt', mode="before")
def validate_limit_mm_per_prompt(cls, val) -> Optional[Mapping[str, int]]:
    """Parse a ``key=value,key=value`` string into a mapping of integer limits.

    Keys and values are lower-cased and stripped of surrounding whitespace
    before the value is converted with ``int``.

    Args:
        val: The raw field value. A comma-separated string of ``key=value``
            items is parsed; ``None`` or an existing mapping is returned
            unchanged.

    Returns:
        A dict mapping each key to its integer value, or ``val`` itself when
        it is ``None`` or already a mapping.

    Raises:
        ValueError: if an item is not of the form ``key=value``, if a value
            cannot be parsed as an int, or if the same key appears with two
            different values.
    """
    # Nothing to parse for an explicit None or an already-structured mapping;
    # calling ``.split`` on them would raise AttributeError.
    if val is None or isinstance(val, Mapping):
        return val

    out_dict: Dict[str, int] = {}
    for item in val.split(","):
        kv_parts = [part.lower().strip() for part in item.split("=")]
        if len(kv_parts) != 2:
            raise ValueError("Each item should be in the form key=value")
        key, value = kv_parts

        try:
            parsed_value = int(value)
        except ValueError as e:
            raise ValueError(
                f"Failed to parse value of item {key}={value}") from e

        # Repeating a key is tolerated only when the value is identical.
        if key in out_dict and out_dict[key] != parsed_value:
            raise ValueError(
                f"Conflicting values specified for key: {key}")
        out_dict[key] = parsed_value
    return out_dict

@model_validator(mode='after')
def validate_pipeline_parallel(self):
    """Ensure pipeline parallelism is not requested.

    Returns:
        The instance itself when ``pipeline_parallel_degree`` equals 1.

    Raises:
        ValueError: if ``pipeline_parallel_degree`` is anything other than 1.
    """
    # Guard clause: the only accepted value is 1.
    if self.pipeline_parallel_degree == 1:
        return self
    raise ValueError(
        "Pipeline parallelism is not supported in vLLM's LLMEngine used in rolling_batch implementation"
    )

def handle_lmi_vllm_config_conflicts(self, additional_vllm_engine_args):
    """Reject vLLM passthrough args that contradict their DJL counterparts.

    For each (DJL name, vLLM name) pair checked below, if both the DJL
    attribute on ``self`` and the vLLM entry in
    ``additional_vllm_engine_args`` are set (not ``None``) and differ, a
    ``ValueError`` is raised.

    Args:
        additional_vllm_engine_args: mapping of vLLM engine-arg names to the
            values supplied alongside the DJL configs.

    Raises:
        ValueError: if a checked pair is set on both sides with different
            values.
    """

    def validate_potential_lmi_vllm_config_conflict(
            lmi_config_name, vllm_config_name):
        lmi_config_val = getattr(self, lmi_config_name)
        vllm_config_val = additional_vllm_engine_args.get(vllm_config_name)
        if vllm_config_val is None or lmi_config_val is None:
            # At most one side is set, so there is nothing to conflict with.
            return
        # NOTE(review): this is a plain != comparison; if passthrough values
        # arrive as strings (e.g. "4") while the DJL field is an int, equal
        # settings would be reported as a conflict - confirm against callers.
        if vllm_config_val != lmi_config_val:
            raise ValueError(
                f"Both the DJL {lmi_config_name}={lmi_config_val} and vLLM {vllm_config_name}={vllm_config_val} configs have been set with conflicting values. "
                f"We currently only accept the DJL config {lmi_config_name}, please remove the vllm {vllm_config_name} configuration."
            )

    validate_potential_lmi_vllm_config_conflict("tensor_parallel_degree",
                                                "tensor_parallel_size")
    validate_potential_lmi_vllm_config_conflict("pipeline_parallel_degree",
                                                "pipeline_parallel_size")
    validate_potential_lmi_vllm_config_conflict("max_rolling_batch_size",
                                                "max_num_seqs")

def generate_vllm_engine_arg_dict(self,
                                  passthrough_vllm_engine_args) -> dict:
    """Assemble the vLLM engine-argument dict from this object's DJL configs.

    Args:
        passthrough_vllm_engine_args: mapping of vLLM argument names to
            values; applied last, so its entries override the DJL-derived
            ones on key collisions.

    Returns:
        A dict keyed by vLLM engine-argument name.
    """
    # Arguments that are always forwarded, even when their value is None.
    engine_args = dict(
        model=self.model_id_or_path,
        tensor_parallel_size=self.tensor_parallel_degree,
        pipeline_parallel_size=self.pipeline_parallel_degree,
        max_num_seqs=self.max_rolling_batch_size,
        dtype=DTYPE_MAPPER[self.dtype],
        revision=self.revision,
        max_loras=self.max_loras,
        enable_lora=self.enable_lora,
        long_lora_scaling_factors=self.long_lora_scaling_factors,
    )
    # Arguments that are forwarded only when explicitly set (not None).
    optional_args = (
        ('quantization', self.quantize),
        ('max_num_batched_tokens', self.max_rolling_batch_prefill_tokens),
        ('cpu_offload_gb', self.cpu_offload_gb_per_gpu),
        ('device', self.device),
        ('preloaded_model', self.preloaded_model),
        ('generation_config', self.generation_config),
    )
    for arg_name, arg_value in optional_args:
        if arg_value is not None:
            engine_args[arg_name] = arg_value
    engine_args.update(passthrough_vllm_engine_args)
    return engine_args

def get_engine_args(self) -> EngineArgs:
    """Build vLLM ``EngineArgs`` by routing the configs through vLLM's CLI parser.

    The passthrough args are collected and checked for conflicts with the DJL
    configs, merged into a single dict, rendered as a CLI-style argument
    list, and parsed with the parser produced by ``EngineArgs.add_cli_args``.

    Returns:
        The ``EngineArgs`` produced by ``EngineArgs.from_cli_args``.

    Raises:
        ValueError: propagated from ``handle_lmi_vllm_config_conflicts`` when
            a DJL config and its vLLM counterpart disagree.
    """
    additional_vllm_engine_args = self.get_additional_vllm_engine_args()
    self.handle_lmi_vllm_config_conflicts(additional_vllm_engine_args)
    vllm_engine_arg_dict = self.generate_vllm_engine_arg_dict(
        additional_vllm_engine_args)
    # Lazy %-formatting: the dict is only rendered when DEBUG is enabled.
    logging.debug(
        "Constructing vLLM engine args from the following DJL configs: %s",
        vllm_engine_arg_dict)
    parser = EngineArgs.add_cli_args(FlexibleArgumentParser())
    args_list = self.construct_vllm_args_list(vllm_engine_arg_dict, parser)
    args = parser.parse_args(args=args_list)
    return EngineArgs.from_cli_args(args)

def get_additional_vllm_engine_args(self) -> Dict[str, Any]:
    """Return the extra properties whose names are annotated on ``EngineArgs``.

    Returns:
        A dict containing every item of ``self.__pydantic_extra__`` whose key
        appears in ``EngineArgs.__annotations__``.
    """
    # Look the annotations up once rather than on every comprehension step.
    engine_arg_names = EngineArgs.__annotations__
    # NOTE(review): pydantic presumably leaves __pydantic_extra__ as None
    # unless the model allows extras; treat that as "no extras" - confirm
    # against the model config.
    extras = self.__pydantic_extra__ or {}
    return {k: v for k, v in extras.items() if k in engine_arg_names}

def construct_vllm_args_list(self, vllm_engine_args: dict,
                             parser: FlexibleArgumentParser):
    """Render an engine-arg dict as a list of ``--name[=value]`` CLI tokens.

    Args:
        vllm_engine_args: mapping of argument name to value.
        parser: parser whose registered ``StoreBoolean`` actions identify the
            arguments that take an explicit boolean value.

    Returns:
        A list of CLI tokens. A true/false value for an argument that is not
        a ``StoreBoolean`` destination becomes a bare ``--name`` when true
        and is omitted when false; every other entry becomes
        ``--name=value``.
    """
    # Modified from https://github.com/vllm-project/vllm/blob/v0.6.4/vllm/utils.py#L1258
    explicit_bool_dests = set()
    for action in parser._actions:
        if isinstance(action, StoreBoolean):
            explicit_bool_dests.add(action.dest)

    cli_tokens = []
    for name, value in vllm_engine_args.items():
        lowered = str(value).lower()
        is_bare_flag = (lowered in {'true', 'false'}
                        and name not in explicit_bool_dests)
        if not is_bare_flag:
            cli_tokens.append(f"--{name}={value}")
        elif lowered == 'true':
            # A false bare flag is expressed by leaving the flag out.
            cli_tokens.append(f"--{name}")
    return cli_tokens
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def __init__(self, model_id_or_path: str, properties: dict,
self.request_cache = OrderedDict()
self.lora_id_counter = AtomicCounter(0)
self.lora_requests = {}
self.is_mistral_tokenizer = self.vllm_configs.tokenizer_mode == 'mistral'
self.is_mistral_tokenizer = args.tokenizer_mode == 'mistral'

def get_tokenizer(self):
return self.engine.tokenizer.tokenizer
Expand Down
Loading

0 comments on commit 1ecdbcf

Please sign in to comment.