diff --git a/launcher/src/main.rs b/launcher/src/main.rs
index e45b83a3..978af49f 100644
--- a/launcher/src/main.rs
+++ b/launcher/src/main.rs
@@ -342,9 +342,18 @@ struct Args {
 
     /// Whether you want to compile the model into a CUDA graph.
     /// This will speed up decoding but increase GPU memory usage.
-    #[clap(long, env, value_enum)]
+    /// Only use either `--compile` or `--eager`. Using both at the same time will
+    /// result in an error.
+    #[clap(default_value = "true", long, env, value_enum)]
     compile: bool,
 
+    /// Whether you want to run the model in eager mode, without
+    /// CUDA graph compilation, or run it with compilation.
+    /// Only use either `--compile` or `--eager`. Using both at the same time will
+    /// result in an error.
+    #[clap(default_value = "false", long, env, value_enum)]
+    eager: bool,
+
     // The maximum batch size past which CUDA graphs are disabled.
     #[clap(default_value = "128", long, env)]
     compile_max_batch_size: usize,
@@ -656,6 +665,7 @@ fn shard_manager(
     adapter_source: String,
     quantize: Option<Quantization>,
     compile: bool,
+    eager: bool,
     compile_max_batch_size: usize,
     compile_max_rank: usize,
     compile_batch_size: usize,
@@ -738,10 +748,14 @@ fn shard_manager(
     }
 
     // CUDA graph compilation
-    if compile {
+    if compile && !eager {
         shard_args.push("--compile".to_string());
     }
 
+    if (compile && eager) || (!compile && !eager) {
+        panic!("Cannot use both --compile and --eager at the same time.");
+    }
+
     // Speculative decoding
     if let Some(speculative_tokens) = speculative_tokens {
         shard_args.push("--speculative-tokens".to_string());
@@ -1303,6 +1317,7 @@ fn spawn_shards(
         let otlp_endpoint = args.otlp_endpoint.clone();
         let quantize = args.quantize;
         let compile = args.compile;
+        let eager = args.eager;
         let compile_max_batch_size = args.compile_max_batch_size;
         let compile_max_rank = args.compile_max_rank;
         let compile_batch_size = args.compile_batch_size;
@@ -1335,6 +1350,7 @@ fn spawn_shards(
                 adapter_source,
                 quantize,
                 compile,
+                eager,
                 compile_max_batch_size,
                 compile_max_rank,
                 compile_batch_size,
diff --git a/server/lorax_server/models/flash_causal_lm.py b/server/lorax_server/models/flash_causal_lm.py
index 92937511..b8b3dbc6 100644
--- a/server/lorax_server/models/flash_causal_lm.py
+++ b/server/lorax_server/models/flash_causal_lm.py
@@ -1181,7 +1181,10 @@ def __init__(
             SLIDING_WINDOW = sliding_window
             SLIDING_WINDOW_BLOCKS = math.ceil(sliding_window / BLOCK_SIZE)
 
-        self.compile = compile
+        self.compile = compile and self.supports_cuda_graph_compilation
+        if compile and not self.supports_cuda_graph_compilation:
+            logger.info("Model does not support CUDA graph compilation. Skipping compilation.")
+
         self.model_graph_wrapper: GraphCache = None
         self.kv_cache = []
 
diff --git a/server/lorax_server/models/model.py b/server/lorax_server/models/model.py
index 696dc2e8..b29338f7 100644
--- a/server/lorax_server/models/model.py
+++ b/server/lorax_server/models/model.py
@@ -228,6 +228,10 @@ def check_initialized(self):
                 f"found uninitialized parameters in model {self.__class__.__name__}: {uninitialized_parameters}"
             )
 
+    @property
+    def supports_cuda_graph_compilation(self) -> bool:
+        return True
+
     @property
     def supports_adapter_loading(self) -> bool:
         return False
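
Note (not part of the diff): a minimal, self-contained Python sketch of the capability check these hunks introduce. The class and function names below are illustrative, not taken from the repository; in the diff itself, the default property lives on the base model class in server/lorax_server/models/model.py and is consumed in flash_causal_lm.py, while the launcher maps --compile/--eager onto the compile flag.

class BaseModel:
    # Mirrors the default added in server/lorax_server/models/model.py:
    # models support CUDA graph compilation unless they opt out.
    @property
    def supports_cuda_graph_compilation(self) -> bool:
        return True


class NoGraphModel(BaseModel):
    # A hypothetical architecture that cannot be captured into CUDA graphs
    # opts out by overriding the property to return False.
    @property
    def supports_cuda_graph_compilation(self) -> bool:
        return False


def resolve_compile(model: BaseModel, compile_requested: bool) -> bool:
    # Mirrors the check added in flash_causal_lm.py: the effective compile
    # flag is the requested flag gated by the model's capability.
    return compile_requested and model.supports_cuda_graph_compilation


assert resolve_compile(BaseModel(), True) is True
assert resolve_compile(NoGraphModel(), True) is False

Design-wise, this keeps the --compile/--eager launcher flags orthogonal to per-model capability: a model that opts out falls back to eager execution with an info log rather than failing.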