diff --git a/launcher/src/main.rs b/launcher/src/main.rs index e45b83a3..8996fe8c 100644 --- a/launcher/src/main.rs +++ b/launcher/src/main.rs @@ -818,12 +818,14 @@ fn shard_manager( // Prefix caching if let Some(prefix_caching) = prefix_caching { - envs.push(("PREFIX_CACHING".into(), prefix_caching.to_string().into())); + let prefix_caching = if prefix_caching { "1" } else { "0" }; + envs.push(("PREFIX_CACHING".into(), prefix_caching.into())); } // Chunked prefill if let Some(chunked_prefill) = chunked_prefill { - envs.push(("CHUNKED_PREFILL".into(), chunked_prefill.to_string().into())); + let chunked_prefill = if chunked_prefill { "1" } else { "0" }; + envs.push(("CHUNKED_PREFILL".into(), chunked_prefill.into())); } // Compile max batch size and rank diff --git a/server/lorax_server/utils/punica.py b/server/lorax_server/utils/punica.py index fe5869e0..248359eb 100644 --- a/server/lorax_server/utils/punica.py +++ b/server/lorax_server/utils/punica.py @@ -20,14 +20,14 @@ try: import punica_kernels as _kernels - HAS_SGMV = not bool(os.environ.get("DISABLE_SGMV", "")) + HAS_SGMV = not bool(int(os.environ.get("DISABLE_SGMV", "0"))) except ImportError: warnings.warn("Could not import SGMV kernel from Punica, falling back to loop.") _kernels = None HAS_SGMV = False -LORAX_PUNICA_TRITON_DISABLED = bool(os.environ.get("LORAX_PUNICA_TRITON_DISABLED", "")) +LORAX_PUNICA_TRITON_DISABLED = bool(int(os.environ.get("LORAX_PUNICA_TRITON_DISABLED", "0"))) if LORAX_PUNICA_TRITON_DISABLED: logger.info("LORAX_PUNICA_TRITON_DISABLED is set, disabling Punica Trion kernels.") diff --git a/server/lorax_server/utils/sources/hub.py b/server/lorax_server/utils/sources/hub.py index 536421d2..4a1d6a2f 100644 --- a/server/lorax_server/utils/sources/hub.py +++ b/server/lorax_server/utils/sources/hub.py @@ -274,7 +274,7 @@ def download_file(self, filename: str, ignore_errors: bool = False) -> Optional[ def get_hub_api(token: Optional[str] = None) -> HfApi: - if token == "" and bool(os.environ.get("LORAX_USE_GLOBAL_HF_TOKEN", 0)): + if token == "" and bool(int(os.environ.get("LORAX_USE_GLOBAL_HF_TOKEN", "0"))): # User initialized LoRAX to fallback to global HF token if request token is empty token = os.environ.get("HUGGING_FACE_HUB_TOKEN") return HfApi(token=token) diff --git a/server/lorax_server/utils/state.py b/server/lorax_server/utils/state.py index 5566208b..d1771654 100644 --- a/server/lorax_server/utils/state.py +++ b/server/lorax_server/utils/state.py @@ -10,12 +10,12 @@ LORAX_PROFILER_DIR = os.environ.get("LORAX_PROFILER_DIR", None) -PREFIX_CACHING = bool(os.environ.get("PREFIX_CACHING", "")) -CHUNKED_PREFILL = bool(os.environ.get("CHUNKED_PREFILL", "")) +PREFIX_CACHING = bool(int(os.environ.get("PREFIX_CACHING", "0"))) +CHUNKED_PREFILL = bool(int(os.environ.get("CHUNKED_PREFILL", "0"))) LORAX_SPECULATION_MAX_BATCH_SIZE = int(os.environ.get("LORAX_SPECULATION_MAX_BATCH_SIZE", 32)) # Always use flashinfer when prefix caching is enabled -FLASH_INFER = bool(os.environ.get("FLASH_INFER", "")) or PREFIX_CACHING +FLASH_INFER = bool(int(os.environ.get("FLASH_INFER", "0"))) or PREFIX_CACHING if FLASH_INFER: logger.info("Backend = flashinfer") else: