From 74110bcd547cd9bcaf4380d5ca5f26375c23f443 Mon Sep 17 00:00:00 2001
From: Ajinkya Tejankar
Date: Fri, 10 Jan 2025 00:22:06 +0530
Subject: [PATCH 1/2] fp8 dynamic activation scaling (disable static scaling)

---
 server/lorax_server/layers/fp8.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/server/lorax_server/layers/fp8.py b/server/lorax_server/layers/fp8.py
index f03d2974a..435ccecd9 100644
--- a/server/lorax_server/layers/fp8.py
+++ b/server/lorax_server/layers/fp8.py
@@ -43,7 +43,7 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:
             input=input,
             qweight=self.qweight,
             weight_scale=self.weight_scale,
-            input_scale=self.input_scale,
+            input_scale=None,
             qbias=self.qbias,
         )

From b38bfb0c306a7e08946321c9339b90a570540d43 Mon Sep 17 00:00:00 2001
From: Ajinkya Tejankar
Date: Sat, 11 Jan 2025 00:45:53 +0530
Subject: [PATCH 2/2] allow channelwise scale factors

---
 server/lorax_server/layers/fp8.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/server/lorax_server/layers/fp8.py b/server/lorax_server/layers/fp8.py
index 435ccecd9..1641eb42a 100644
--- a/server/lorax_server/layers/fp8.py
+++ b/server/lorax_server/layers/fp8.py
@@ -14,7 +14,7 @@ def apply_fp8_linear(
     input_scale_ub: Optional[torch.Tensor] = None,
     qbias: Optional[torch.Tensor] = None,
 ) -> torch.Tensor:
-    qinput, x_scale = ops.scaled_fp8_quant(input, input_scale, scale_ub=input_scale_ub, use_per_token_if_dynamic=False)
+    qinput, x_scale = ops.scaled_fp8_quant(input, input_scale, scale_ub=input_scale_ub, use_per_token_if_dynamic=True)
     output = ops.cutlass_scaled_mm(
         qinput, qweight, out_dtype=input.dtype, scale_a=x_scale, scale_b=weight_scale, bias=qbias
     )
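
Net effect of the two patches: apply_fp8_linear stops using the statically
calibrated input_scale (patch 1 passes None) and asks scaled_fp8_quant to
compute a fresh scale per token at runtime (patch 2 sets
use_per_token_if_dynamic=True), which cutlass_scaled_mm then combines with a
channelwise (per-output-channel) weight_scale. The sketch below is a minimal
pure-PyTorch emulation of those numerics, not the vLLM CUTLASS kernels; the
float8_e4m3fn format, the amax-based scale recipe, and the tensor shapes are
illustrative assumptions.

import torch

# Largest representable magnitude in the assumed float8_e4m3fn format.
FP8_MAX = torch.finfo(torch.float8_e4m3fn).max

def dynamic_per_token_quant(x: torch.Tensor):
    # One scale per token (row): row-wise amax divided by the fp8 max,
    # mirroring what use_per_token_if_dynamic=True asks the kernel to do.
    x_scale = (x.abs().amax(dim=-1, keepdim=True).float() / FP8_MAX).clamp(min=1e-12)
    qx = (x.float() / x_scale).clamp(-FP8_MAX, FP8_MAX).to(torch.float8_e4m3fn)
    return qx, x_scale  # qx: [T, K] fp8, x_scale: [T, 1] fp32

def emulated_scaled_mm(qinput, qweight, x_scale, weight_scale, out_dtype, qbias=None):
    # Scaled-matmul semantics: dequantize the fp8 product by broadcasting the
    # per-token scale over rows and the channelwise weight scale over columns.
    out = (qinput.float() @ qweight.float()) * x_scale * weight_scale
    if qbias is not None:
        out = out + qbias
    return out.to(out_dtype)

# Hypothetical usage: T tokens, K in-features, N out-features.
T, K, N = 4, 64, 32
input = torch.randn(T, K, dtype=torch.float16)
weight = torch.randn(K, N, dtype=torch.float16)
weight_scale = weight.abs().amax(dim=0, keepdim=True).float() / FP8_MAX  # [1, N], channelwise
qweight = (weight.float() / weight_scale).clamp(-FP8_MAX, FP8_MAX).to(torch.float8_e4m3fn)
qinput, x_scale = dynamic_per_token_quant(input)
output = emulated_scaled_mm(qinput, qweight, x_scale, weight_scale, input.dtype)

Under the pre-patch behavior, x_scale would instead be a single precomputed
per-tensor scalar applied to every token; the dynamic per-token path trades a
small runtime amax reduction for better handling of activation outliers.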