diff --git a/lmdeploy/pytorch/models/qwen2_vl.py b/lmdeploy/pytorch/models/qwen2_vl.py
index ced57f79f4..1a1dc1e1da 100644
--- a/lmdeploy/pytorch/models/qwen2_vl.py
+++ b/lmdeploy/pytorch/models/qwen2_vl.py
@@ -388,10 +388,11 @@ def forward(
             inputs_embeds=inputs_embeds,
             mrope_position_ids=mrope_position_ids,
         )
+        return hidden_states
 
-        logits = self.lm_head(hidden_states)
-        logits = logits.float()
-        return logits
+    def get_logits(self, hidden_states: torch.Tensor):
+        """compute logits of the model output."""
+        return self.lm_head(hidden_states)
 
     def update_weights(self):
         """update weights."""