diff --git a/server/lorax_server/layers/fp8.py b/server/lorax_server/layers/fp8.py index f03d2974a..1641eb42a 100644 --- a/server/lorax_server/layers/fp8.py +++ b/server/lorax_server/layers/fp8.py @@ -14,7 +14,7 @@ def apply_fp8_linear( input_scale_ub: Optional[torch.Tensor] = None, qbias: Optional[torch.Tensor] = None, ) -> torch.Tensor: - qinput, x_scale = ops.scaled_fp8_quant(input, input_scale, scale_ub=input_scale_ub, use_per_token_if_dynamic=False) + qinput, x_scale = ops.scaled_fp8_quant(input, input_scale, scale_ub=input_scale_ub, use_per_token_if_dynamic=True) output = ops.cutlass_scaled_mm( qinput, qweight, out_dtype=input.dtype, scale_a=x_scale, scale_b=weight_scale, bias=qbias @@ -43,7 +43,7 @@ def forward(self, input: torch.Tensor) -> torch.Tensor: input=input, qweight=self.qweight, weight_scale=self.weight_scale, - input_scale=self.input_scale, + input_scale=None, qbias=self.qbias, )