From PyTorch Tensor to PIL Image: A Line-by-Line Breakdown

Demystifying the common three-line pattern for visualizing a batch of images using torchvision, NumPy, and PIL.


The Code

When working with image generation or classification in PyTorch, you often need to visualize a batch of images. The following three-line snippet is ubiquitous:

grid = torchvision.utils.make_grid(x)
grid_im = grid.detach().cpu().permute(1, 2, 0).clip(0, 1) * 255
grid_im = Image.fromarray(np.array(grid_im).astype(np.uint8))

It looks compact, but each step does something important. Let’s break it down.


Step 1: Arrange Images into a Grid

grid = torchvision.utils.make_grid(x)

x is a batch of image tensors with shape [B, C, H, W] — batch size, channels, height, width.

make_grid stitches them together into a single image tensor of shape [C, H', W'], where H' and W' depend on how many images fit per row (controlled by the nrow parameter, default 8). This makes it easy to visualize an entire batch at a glance.


Step 2: Convert to a NumPy-Friendly Format

grid_im = grid.detach().cpu().permute(1, 2, 0).clip(0, 1) * 255

This is a chain of five operations:

Operation Purpose
.detach() Detach the tensor from the computation graph so gradients are not tracked. Required before converting to NumPy.
.cpu() Move the tensor from GPU to CPU memory. NumPy and PIL cannot operate on CUDA tensors.
.permute(1, 2, 0) Rearrange dimensions from [C, H, W] (PyTorch convention) to [H, W, C] (image library convention, e.g. RGB last).
.clip(0, 1) Clamp pixel values to the [0, 1] range. Neural network outputs may slightly exceed this range due to floating-point arithmetic.
* 255 Scale from float [0, 1] to integer [0, 255], the standard 8-bit color depth.

Why .detach() Before .cpu()?

Calling .cpu() on a tensor that still requires gradients would maintain the autograd graph across devices. .detach() first ensures a clean, gradient-free copy.

Why .permute(1, 2, 0)?

PyTorch stores images as CHW (Channel, Height, Width) because convolution operations are optimized for this layout. However, display libraries (PIL, matplotlib, OpenCV) expect HWC (Height, Width, Channel). The permute call performs a zero-copy dimension reorder.


Step 3: Create a PIL Image

grid_im = Image.fromarray(np.array(grid_im).astype(np.uint8))
Operation Purpose
np.array(grid_im) Convert the PyTorch tensor to a NumPy ndarray.
.astype(np.uint8) Cast to unsigned 8-bit integer — the pixel dtype PIL expects. Without this, fromarray may throw or produce garbage.
Image.fromarray(...) Construct a PIL Image from the NumPy array. The result can be displayed with .show(), saved with .save("out.png"), or rendered inline in Jupyter with display().

The Full Pipeline at a Glance

[B, C, H, W] tensor (GPU, grad-tracked, float32)
        │  make_grid
        ▼
[C, H', W'] tensor (GPU, grad-tracked, float32)
        │  detach → cpu → permute → clip → *255
        ▼
[H', W', C] tensor (CPU, no grad, float32, 0–255)
        │  np.array → astype(uint8) → Image.fromarray
        ▼
PIL Image (RGB, uint8)

Tips

  • Normalize first. If your images are in [-1, 1] (common after transforms.Normalize), pass normalize=True to make_grid or manually remap with (x + 1) / 2 before this pipeline.
  • Use nrow for layout control. make_grid(x, nrow=4) arranges images in 4 columns.
  • padding and pad_value. make_grid(x, padding=2, pad_value=1) adds white borders between images for cleaner visualization.