From af689934c8e845c94c846838a58ae61b475b6c5f Mon Sep 17 00:00:00 2001 From: Noah Yoshida Date: Mon, 11 Mar 2024 14:27:03 -0700 Subject: [PATCH 1/2] use default value derived from max input tokens --- launcher/src/main.rs | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/launcher/src/main.rs b/launcher/src/main.rs index f7a6d475c..98c6f9c50 100644 --- a/launcher/src/main.rs +++ b/launcher/src/main.rs @@ -220,8 +220,10 @@ struct Args { /// Limits the number of tokens for the prefill operation. /// Since this operation take the most memory and is compute bound, it is interesting /// to limit the number of requests that can be sent. - #[clap(default_value = "4096", long, env)] - max_batch_prefill_tokens: u32, + /// The default value will be set based on the max_input_length, since it cannot be less than + /// that value. + #[clap(long, env)] + max_batch_prefill_tokens: Option<u32>, /// **IMPORTANT** This is one critical control to allow maximum usage /// of the available hardware. @@ -1211,7 +1213,7 @@ fn main() -> Result<(), LauncherError> { tracing::info!("Sharding model on {num_shard} processes"); } - if let Some(ref max_batch_total_tokens) = args.max_batch_total_tokens { + if let Some(ref max_batch_total_tokens) = args.max_batch_total_tokens.unwrap_or_else(args.max_batch_prefill_tokens) { if args.max_batch_prefill_tokens > *max_batch_total_tokens { return Err(LauncherError::ArgumentValidation(format!( "`max_batch_prefill_tokens` must be <= `max_batch_total_tokens`. 
Given: {} and {}", From 4faf2f09f051aa6668bfe5b527c647436e51bb50 Mon Sep 17 00:00:00 2001 From: Noah Yoshida Date: Mon, 11 Mar 2024 14:27:21 -0700 Subject: [PATCH 2/2] using the value of max input tokens to derive default max batch prefill tokens --- launcher/src/main.rs | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/launcher/src/main.rs b/launcher/src/main.rs index 98c6f9c50..eaf310bb0 100644 --- a/launcher/src/main.rs +++ b/launcher/src/main.rs @@ -1183,6 +1183,13 @@ fn main() -> Result<(), LauncherError> { tracing::info!("{:?}", args); + // Set default values derived from other args + + // If the value of max_batch_prefill_tokens is not specified, default to max_input_length + if args.max_batch_prefill_tokens.is_none() { + args.max_batch_prefill_tokens = Option(args.max_input_length as u32) + } + // Validate args if args.max_input_length >= args.max_total_tokens { return Err(LauncherError::ArgumentValidation( @@ -1213,8 +1220,8 @@ fn main() -> Result<(), LauncherError> { tracing::info!("Sharding model on {num_shard} processes"); } - if let Some(ref max_batch_total_tokens) = args.max_batch_total_tokens.unwrap_or_else(args.max_batch_prefill_tokens) { - if args.max_batch_prefill_tokens > *max_batch_total_tokens { + if let Some(ref max_batch_total_tokens) = args.max_batch_total_tokens { + if args.max_batch_prefill_tokens.unwrap() > *max_batch_total_tokens { return Err(LauncherError::ArgumentValidation(format!( "`max_batch_prefill_tokens` must be <= `max_batch_total_tokens`. Given: {} and {}", args.max_batch_prefill_tokens, max_batch_total_tokens