Skip to content

Commit

Permalink
added colorbar for 2d image export
Browse files Browse the repository at this point in the history
  • Loading branch information
rgerum committed Jan 31, 2024
1 parent f3d4d8d commit 6377d91
Show file tree
Hide file tree
Showing 3 changed files with 95 additions and 9 deletions.
4 changes: 4 additions & 0 deletions saenopy/gui/solver/modules/exporter/ExportRenderCommon.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,13 @@ def get_mesh_arrows(params, result):
if data["fields"][params["arrows"]]["measure"] == "deformation":
if mesh is not None and field is not None:
return mesh, field, params["deformation_arrows"], data["fields"][params["arrows"]]["name"]
else:
return None, None, params["deformation_arrows"], data["fields"][params["arrows"]]["name"]
if data["fields"][params["arrows"]]["measure"] == "force":
if mesh is not None and field is not None:
return mesh, field, params["force_arrows"], data["fields"][params["arrows"]]["name"]
else:
return None, None, params["force_arrows"], data["fields"][params["arrows"]]["name"]
return None, None, {}, ""


Expand Down
96 changes: 89 additions & 7 deletions saenopy/gui/solver/modules/exporter/ExporterRender2D.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,15 @@ def render_2d(params, result, exporter=None):
if pil_image is None:
return np.zeros((10, 10))

pil_image = render_2d_arrows(params, result, pil_image, im_scale, aa_scale, display_image)
pil_image, disp_params = render_2d_arrows(params, result, pil_image, im_scale, aa_scale, display_image, return_scale=True)

if aa_scale == 2:
pil_image = pil_image.resize([pil_image.width // 2, pil_image.height // 2])
aa_scale = 1

pil_image = render_2d_scalebar(params, result, pil_image, im_scale, aa_scale)
if disp_params != None:
pil_image = render_2d_colorbar(params, result, pil_image, im_scale, aa_scale, scale_max=disp_params["scale_max"], colormap=disp_params["colormap"])

pil_image = render_2d_time(params, result, pil_image)

Expand Down Expand Up @@ -76,16 +78,23 @@ def project_data(R, field, skip=1):

mesh, field, params_arrows, name = get_mesh_arrows(params, result)

if params_arrows is None:
scale_max = None
else:
scale_max = params_arrows["scale_max"] if not params_arrows["autoscale"] else None
colormap = params_arrows["colormap"]
skip = params_arrows["skip"]
alpha = params_arrows["arrow_opacity"]

if mesh is None:
if return_scale:
if scale_max is None:
return pil_image, None
else:
return pil_image, {"scale_max": scale_max, "colormap": colormap}
return pil_image, None
return pil_image

scale_max = params_arrows["scale_max"] if not params_arrows["autoscale"] else None
colormap = params_arrows["colormap"]
skip = params_arrows["skip"]
alpha = params_arrows["arrow_opacity"]

if field is not None:
# rescale and offset
scale = 1e6 / display_image[1][0]
Expand Down Expand Up @@ -133,7 +142,7 @@ def project_data(R, field, skip=1):
headlength=params["2D_arrows"]["headlength"],
headheight=params["2D_arrows"]["headheight"])
if return_scale:
return pil_image, scale_max
return pil_image, {"scale_max": scale_max, "colormap": colormap}
return pil_image


Expand Down Expand Up @@ -166,6 +175,22 @@ def getBarParameters(pixtomu, scale=1):
size_in_um=mu, color="w", unit="µm")
return pil_image

def render_2d_colorbar(params, result, pil_image, im_scale, aa_scale, colormap="viridis", scale_max=1):
pil_image = add_colorbar(pil_image, scale=1,
colormap=colormap,#params["colorbar"]["colorbar"],
#bar_width=params["colorbar"]["bar_width"] * aa_scale,
#bar_height=params["colorbar"]["bar_height"] * aa_scale,
#tick_height=params["colorbar"]["tick_height"] * aa_scale,
#tick_count=params["colorbar"]["tick_count"],
#min_v=params["scalebar"]["min_v"],
max_v=scale_max,#params["colorbar"]["max_v"],
#offset_x=params["colorbar"]["offset_x"] * aa_scale,
#offset_y=params["colorbar"]["offset_y"] * aa_scale,
#fontsize=params["colorbar"]["fontsize"] * aa_scale,
)

return pil_image


def render_2d_time(params, result, pil_image):
data = result.get_data_structure()
Expand Down Expand Up @@ -245,6 +270,63 @@ def add_text(pil_image, text, position, fontsize=18):
image.text((x, y), text, color, font=font)
return pil_image

def add_colorbar(pil_image,
colormap="viridis",
bar_width=150,
bar_height=10,
tick_height=5,
tick_count=3,
min_v=0,
max_v=10,
offset_x=15,
offset_y=-10,
scale=1, fontsize=16, color="w"):
cmap = plt.get_cmap(colormap)
if offset_x < 0:
offset_x = pil_image.size[0] + offset_x
if offset_y < 0:
offset_y = pil_image.size[1] + offset_y

color = tuple((matplotlib.colors.to_rgba_array(color)[0, :3] * 255).astype("uint8"))
if pil_image.mode != "RGB":
color = int(np.mean(color))

colors = np.zeros((bar_height, bar_width, 3), dtype=np.uint8)
for i in range(bar_width):
c = plt.get_cmap(cmap)(int(i / bar_width * 255))
colors[:, i, :] = [c[0] * 255, c[1] * 255, c[2] * 255]
pil_image.paste(Image.fromarray(colors), (offset_x, offset_y - bar_height))

image = ImageDraw.ImageDraw(pil_image)
import matplotlib.ticker as ticker

font_size = int(
round(fontsize * scale * 4 / 3)) # the 4/3 appears to be a factor of "converting" screel dpi to image dpi
try:
font = ImageFont.truetype("arial", font_size) # ImageFont.truetype("tahoma.ttf", font_size)
except IOError:
font = ImageFont.truetype("times", font_size)

locator = ticker.MaxNLocator(nbins=tick_count - 1)
#tick_positions = locator.tick_values(min_v, max_v)
tick_positions = np.linspace(min_v, max_v, tick_count)
for i, pos in enumerate(tick_positions):
x0 = offset_x + (bar_width - 2) / (tick_count - 1) * i
y0 = offset_y - bar_height - 1

image.rectangle([x0, y0-5, x0+1, y0])

text = "%d" % pos
length_number = image.textlength(text, font=font)
height_number = image.textbbox((0, 0), text, font=font)[3]

x = x0 - length_number * 0.5 + 1
y = y0 - height_number - tick_height - 3
# draw the text for the number and the unit
image.text((x, y), text, color, font=font)
#image.rectangle([pil_image.size[0]-10, 0, pil_image.size[0], 10], fill="w")
return pil_image

def add_scalebar(pil_image, scale, image_scale, width, xpos, ypos, fontsize, pixel_width, size_in_um, color="w", unit="µm"):
image = ImageDraw.ImageDraw(pil_image)
pixel_height = width
Expand Down
4 changes: 2 additions & 2 deletions saenopy/gui/spheroid/modules/DeformationDetector.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,7 @@ def update_display(self, *, plotter=None):
pil_image = pil_image.resize(
[int(pil_image.width * im_scale * aa_scale), int(pil_image.height * im_scale * aa_scale)])
#print(self.auto_scale.value(), self.getScaleMax())
pil_image, scale_max = render_2d_arrows({
pil_image, disp_params = render_2d_arrows({
'arrows': 'deformation',
'deformation_arrows': {
"autoscale": self.auto_scale.value(),
Expand All @@ -228,7 +228,7 @@ def update_display(self, *, plotter=None):
self.pixmap.setPixmap(QtGui.QPixmap(array2qimage(im)))
self.label.setExtend(im.shape[1], im.shape[0])
self.scale1.setScale([self.result.pixel_size])
self.color1.setScale(0, scale_max, self.colormap_chooser.value())
self.color1.setScale(0, disp_params["scale_max"] if disp_params else None, self.colormap_chooser.value())

if self.show_seg.value():
thresh_segmentation = self.thresh_segmentation.value()
Expand Down

0 comments on commit 6377d91

Please sign in to comment.