| import numpy as np |
| import math |
| import PIL |
|
|
def postprocess(x):
    """Convert an array with values in [0, 1] to uint8 in [0, 255].

    Values are scaled by 255, clipped to [0, 255], and then truncated
    (not rounded) when cast to uint8.

    Args:
        x: Array-like of floats, nominally in [0, 1]; out-of-range values
            are clipped.

    Returns:
        ``np.ndarray`` of dtype ``uint8`` with the same shape as ``x``.
    """
    # asarray first so a plain list is scaled, not list-repeated by ``255 *``.
    x = np.clip(255 * np.asarray(x), 0, 255)
    # ``np.cast`` was removed in NumPy 2.0; ``astype`` is the supported form.
    return x.astype(np.uint8)
|
|
def tile(X, rows, cols):
    """Lay out a stack of images on a ``rows`` x ``cols`` grid.

    ``X`` is indexed as ``X[n, h, w, c]``. Image ``n`` goes to grid cell
    ``(n // cols, n % cols)``, i.e. row-major order. Cells with no image stay
    zero, and images beyond ``rows * cols`` are not drawn.

    Returns:
        Array of shape ``(rows * h, cols * w, c)`` with ``X.dtype``.
    """
    n_images = X.shape[0]
    height = X.shape[1]
    width = X.shape[2]
    channels = X.shape[3]
    canvas = np.zeros((rows * height, cols * width, channels), dtype=X.dtype)
    for k in range(min(rows * cols, n_images)):
        row, col = divmod(k, cols)
        top, left = row * height, col * width
        canvas[top:top + height, left:left + width, :] = X[k, ...]
    return canvas
|
|
|
|
def plot_batch(X, out_path):
    """Tile a batch of images into a square grid and save it to ``out_path``.

    ``X`` is indexed as ``X[n, h, w, c]`` with values nominally in [0, 1]
    (``postprocess`` scales and clips them to uint8). If there are more than
    three channels, three distinct channels are chosen at random so the grid
    can be saved as RGB. The grid is ``ceil(sqrt(n))`` cells on each side. A
    trailing single-channel axis is dropped so grayscale batches are saved
    as 2-D images.

    Raises:
        ValueError: if the batch contains no images.
    """
    # ``import PIL`` alone does not guarantee the ``PIL.Image`` submodule is loaded.
    from PIL import Image

    if X.shape[0] == 0:
        raise ValueError("plot_batch needs at least one image")
    n_channels = X.shape[3]
    if n_channels > 3:
        # Sample without replacement so no channel is displayed twice.
        X = X[:, :, :, np.random.choice(n_channels, size=3, replace=False)]
    X = postprocess(X)
    side = math.ceil(math.sqrt(X.shape[0]))
    canvas = tile(X, side, side)
    # Drop only the channel axis; a blanket squeeze would also collapse
    # spatial axes of size 1.
    if canvas.shape[-1] == 1:
        canvas = canvas[..., 0]
    Image.fromarray(canvas).save(out_path)