From PyTorch Tensor to PIL Image: A Line-by-Line Breakdown
From PyTorch Tensor to PIL Image: A Line-by-Line Breakdown
Demystifying the common three-line pattern for visualizing a batch of images using
torchvision,NumPy, andPIL.
The Code
When working with image generation or classification in PyTorch, you often need to visualize a batch of images. The following three-line snippet is ubiquitous:
grid = torchvision.utils.make_grid(x)
grid_im = grid.detach().cpu().permute(1, 2, 0).clip(0, 1) * 255
grid_im = Image.fromarray(np.array(grid_im).astype(np.uint8))
It looks compact, but each step does something important. Let’s break it down.
Step 1: Arrange Images into a Grid
grid = torchvision.utils.make_grid(x)
x is a batch of image tensors with shape [B, C, H, W] — batch size, channels, height, width.
make_grid stitches them together into a single image tensor of shape [C, H', W'], where H' and W' depend on how many images fit per row (controlled by the nrow parameter, default 8). This makes it easy to visualize an entire batch at a glance.
Step 2: Convert to a NumPy-Friendly Format
grid_im = grid.detach().cpu().permute(1, 2, 0).clip(0, 1) * 255
This is a chain of five operations:
| Operation | Purpose |
|---|---|
.detach() |
Detach the tensor from the computation graph so gradients are not tracked. Required before converting to NumPy. |
.cpu() |
Move the tensor from GPU to CPU memory. NumPy and PIL cannot operate on CUDA tensors. |
.permute(1, 2, 0) |
Rearrange dimensions from [C, H, W] (PyTorch convention) to [H, W, C] (image library convention, e.g. RGB last). |
.clip(0, 1) |
Clamp pixel values to the [0, 1] range. Neural network outputs may slightly exceed this range due to floating-point arithmetic. |
* 255 |
Scale from float [0, 1] to integer [0, 255], the standard 8-bit color depth. |
Why .detach() Before .cpu()?
Calling .cpu() on a tensor that still requires gradients would maintain the autograd graph across devices. .detach() first ensures a clean, gradient-free copy.
Why .permute(1, 2, 0)?
PyTorch stores images as CHW (Channel, Height, Width) because convolution operations are optimized for this layout. However, display libraries (PIL, matplotlib, OpenCV) expect HWC (Height, Width, Channel). The permute call performs a zero-copy dimension reorder.
Step 3: Create a PIL Image
grid_im = Image.fromarray(np.array(grid_im).astype(np.uint8))
| Operation | Purpose |
|---|---|
np.array(grid_im) |
Convert the PyTorch tensor to a NumPy ndarray. |
.astype(np.uint8) |
Cast to unsigned 8-bit integer — the pixel dtype PIL expects. Without this, fromarray may throw or produce garbage. |
Image.fromarray(...) |
Construct a PIL Image from the NumPy array. The result can be displayed with .show(), saved with .save("out.png"), or rendered inline in Jupyter with display(). |
The Full Pipeline at a Glance
[B, C, H, W] tensor (GPU, grad-tracked, float32)
│ make_grid
▼
[C, H', W'] tensor (GPU, grad-tracked, float32)
│ detach → cpu → permute → clip → *255
▼
[H', W', C] tensor (CPU, no grad, float32, 0–255)
│ np.array → astype(uint8) → Image.fromarray
▼
PIL Image (RGB, uint8)
Tips
- Normalize first. If your images are in
[-1, 1](common aftertransforms.Normalize), passnormalize=Truetomake_gridor manually remap with(x + 1) / 2before this pipeline. - Use
nrowfor layout control.make_grid(x, nrow=4)arranges images in 4 columns. paddingandpad_value.make_grid(x, padding=2, pad_value=1)adds white borders between images for cleaner visualization.