diff --git a/launcher/src/main.rs b/launcher/src/main.rs
index f7a6d475c..eaf310bb0 100644
--- a/launcher/src/main.rs
+++ b/launcher/src/main.rs
@@ -220,8 +220,10 @@ struct Args {
     /// Limits the number of tokens for the prefill operation.
     /// Since this operation take the most memory and is compute bound, it is interesting
     /// to limit the number of requests that can be sent.
-    #[clap(default_value = "4096", long, env)]
-    max_batch_prefill_tokens: u32,
+    /// When not specified, the value defaults to `max_input_length`, since it cannot be
+    /// less than that value.
+    #[clap(long, env)]
+    max_batch_prefill_tokens: Option<u32>,

     /// **IMPORTANT** This is one critical control to allow maximum usage
     /// of the available hardware.
@@ -1181,6 +1183,13 @@ fn main() -> Result<(), LauncherError> {

     tracing::info!("{:?}", args);

+    // Set default values derived from other args
+
+    // If the value of max_batch_prefill_tokens is not specified, default to max_input_length
+    if args.max_batch_prefill_tokens.is_none() {
+        args.max_batch_prefill_tokens = Some(args.max_input_length as u32);
+    }
+
     // Validate args
     if args.max_input_length >= args.max_total_tokens {
         return Err(LauncherError::ArgumentValidation(
@@ -1212,7 +1221,10 @@ fn main() -> Result<(), LauncherError> {
     }

     if let Some(ref max_batch_total_tokens) = args.max_batch_total_tokens {
-        if args.max_batch_prefill_tokens > *max_batch_total_tokens {
+        let max_batch_prefill_tokens = args
+            .max_batch_prefill_tokens
+            .expect("max_batch_prefill_tokens is defaulted to max_input_length above");
+        if max_batch_prefill_tokens > *max_batch_total_tokens {
             return Err(LauncherError::ArgumentValidation(format!(
                 "`max_batch_prefill_tokens` must be <= `max_batch_total_tokens`. Given: {} and {}",
-                args.max_batch_prefill_tokens, max_batch_total_tokens
+                max_batch_prefill_tokens, max_batch_total_tokens