
Commit

fix llava model when input images have size (x, 1)
irexyc committed Aug 28, 2024
1 parent 86c9569 commit b4e0b47
Showing 2 changed files with 6 additions and 3 deletions.
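Background for the change: a PIL image of size (x, 1) converts to a numpy array of shape (1, x, 3), and the leading 1 can be mistaken for a channel axis when the Hugging Face image processor guesses the data layout. Passing input_data_format='channels_last' makes the layout explicit so no guess is needed. A minimal sketch of the ambiguity, assuming transformers' public infer_channel_dimension_format helper and a hypothetical width of 336:

import numpy as np
from PIL import Image
from transformers.image_utils import infer_channel_dimension_format

# A 1-pixel-high RGB image: PIL size (336, 1) -> numpy shape (1, 336, 3).
image = Image.new('RGB', (336, 1), color='white')
array = np.asarray(image)
print(array.shape)  # (1, 336, 3)

# Both the first axis (1) and the last axis (3) look like plausible channel
# axes, so the inferred format can come back as channels-first and the image
# is later reinterpreted as a single-channel 336x3 picture.
print(infer_channel_dimension_format(array))  # expected: ChannelDimension.FIRST

# The commit sidesteps the guess by stating the layout explicitly:
# self.processor(images, return_tensors='pt',
#                input_data_format='channels_last')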
5 changes: 3 additions & 2 deletions lmdeploy/vl/model/llava_hf.py
@@ -52,8 +52,9 @@ def build_model(self):
     @torch.no_grad()
     def forward(self, images: List[Image]) -> List[torch.Tensor]:
         """forward."""
-        pixel_values = self.processor(images,
-                                      return_tensors='pt')['pixel_values']
+        pixel_values = self.processor(
+            images, return_tensors='pt',
+            input_data_format='channels_last')['pixel_values']
         pixel_values = pixel_values.to(device=self.model.device,
                                        dtype=self.model.dtype)
         image_outputs = self.model.vision_tower.forward(
4 changes: 3 additions & 1 deletion lmdeploy/vl/model/llava_next.py
@@ -75,7 +75,9 @@ def forward(self, images: List[Image]) -> List[torch.Tensor]:
         from transformers.models.llava_next.modeling_llava_next import \
             image_size_to_num_patches
         """forward."""
-        processed_inputs = self.processor(images, return_tensors='pt')
+        processed_inputs = self.processor(images,
+                                          return_tensors='pt',
+                                          input_data_format='channels_last')
         pixel_values = processed_inputs['pixel_values'].to(
             device=self.model.device, dtype=self.model.dtype)
         image_sizes = processed_inputs['image_sizes'].to(
