Shifted noise for receptive field estimation
For receptive field mapping of stimulus-driven neuronal responses, Gaussian checkerboard noise is popular due to its statistical properties (each box is an independent Gaussian sample). There is a trade-off when choosing the resolution of the grid. Small boxes increase the resolution of the estimated 2D receptive field, but as the box size is reduced, so is the likelihood that a group of nearby pixels will collectively elicit a response from a cell. One proposed solution is to use a coarse grid and add random offsets to it, with the offsets being multiples of the desired finer box size. This notebook is an empirical investigation comparing this shifted noise to standard checkerboard noise. It provides evidence that shifted noise cannot be considered a drop-in replacement with the same statistical properties as vanilla checkerboard noise.
import math
from typing import Iterable
import functools
import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt
from matplotlib.backends.backend_agg import FigureCanvasAgg
import scienceplots
import cmocean
import PIL
import einops
import skimage.draw
import skimage.io
import moviepy
from moviepy import ImageSequenceClip
import io
world_height, world_width = 400, 400
box_size = 100
assert world_height % box_size == world_width % box_size == 0
noise_height = world_height // box_size
noise_width = world_width // box_size
num_shifts = 5
assert box_size % num_shifts == 0
small_box_size = box_size // num_shifts
small_noise_height = noise_height * num_shifts
small_noise_width = noise_width * num_shifts
kernel_len = 25
kernel_offset = (12, 0)
demo_len = 50
rng = np.random.default_rng(123)
def init_matplotlib():
    plt.style.use(["science", "ieee"])
    plt.rcParams.update({
        "font.size": 8,
        "legend.fontsize": 8,
        "axes.titlesize": 9,
        "axes.grid": False,
        "figure.dpi": 96 * 2,
        "xtick.top": False})
init_matplotlib()
1. Checkerboard noise (small, big and shifting)
The vanilla checkerboard noise assigns each square a random intensity, independent across frames and grid positions.
Below we generate Gaussian checkerboard noise with two box sizes. Section 2 then shows shifted noise, whose grid has the larger box size and whose offsets are multiples of the smaller box size.
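As a quick sanity check of that independence, the sketch below builds a small checkerboard stimulus the same way `noise_frame` does: independent N(0, 1) boxes are upsampled with `np.kron`. It then estimates pixel correlations. The grid and box sizes are arbitrary toy values, not the notebook's settings.

```python
import numpy as np

rng = np.random.default_rng(0)
n_frames, grid, box = 20000, 2, 3

# Independent N(0, 1) value per box and frame, upsampled to box x box pixels.
coarse = rng.normal(size=(n_frames, grid, grid))
frames = np.kron(coarse, np.ones((1, box, box)))  # shape (n_frames, 6, 6)

def pixel_corr(a, b):
    """Pearson correlation between two pixel time series."""
    return np.corrcoef(a, b)[0, 1]

# Two pixels inside the same box always carry the same value...
within_box = pixel_corr(frames[:, 0, 0], frames[:, 0, box - 1])
# ...while pixels in neighbouring boxes are uncorrelated up to sampling error.
across_boxes = pixel_corr(frames[:, 0, box - 1], frames[:, 0, box])
print(f"within box: {within_box:.3f}, across boxes: {across_boxes:.3f}")
```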
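def noise_frame(height: int, width: int, box_size: int = 1, batch_size: int = 1):
    """Create frames of a checkerboard noise stimulus.

    Each of the height x width boxes gets an independent N(0, 1) value per
    frame; np.kron then upsamples every box to box_size x box_size pixels.
    """
    noise = rng.normal(loc=0, scale=1, size=(batch_size, height, width))
    if box_size > 1:
        noise = np.kron(noise, np.ones(shape=(1, box_size, box_size)))
    return noise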
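def stimulus_clip(stimulus: np.ndarray, zoom=1, fps=2):
    """Turn a (frames, height, width) stimulus into a grayscale video clip."""
    expand_tensor = np.ones(shape=(1, zoom, zoom))
    stimulus = np.kron(stimulus, expand_tensor)
    # Map N(0, 1) values to 8-bit gray: [-1, 1] -> [0, 255], clipping the tails
    # (the same range as the vmin/vmax used in the plots below).
    stimulus = np.clip((stimulus + 1) / 2, 0, 1) * 255
    stimulus = einops.repeat(stimulus.astype(np.uint8), "b h w -> b h w 3")
    clip = ImageSequenceClip(list(stimulus), fps=fps)
    return clip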
1.1 Small boxes
small_noise_fn = functools.partial(noise_frame, small_noise_height, small_noise_width, box_size=small_box_size)
small_noise_demo = small_noise_fn(batch_size=demo_len)
stimulus_clip(small_noise_demo).display_in_notebook(rd_kwargs={'logger':None})
1.2 Big boxes
big_noise_fn = functools.partial(noise_frame, noise_height, noise_width, box_size=box_size)
big_noise_demo = big_noise_fn(batch_size=demo_len)
stimulus_clip(big_noise_demo).display_in_notebook(rd_kwargs={'logger':None})
2. Shifted noise
This checkerboard noise has boxes the same size as those in the “Big boxes” stimulus above. On each frame, the grid is shifted vertically and horizontally by a random multiple of the small box size. If all possible grids are overlaid on top of each other, the resulting grid has boxes the same size as those in the “Small boxes” stimulus.
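The claim that overlaying all shifted grids gives the small-box grid can be checked along one axis. The sketch below assumes the same convention as `shifted_noise`: displayed pixel `y` reads expanded-noise index `y + shift`, so box edges appear where `y + shift` is a multiple of `box_size`. The sizes match the notebook's configuration.

```python
box_size, num_shifts, width = 100, 5, 400
shift_len = box_size // num_shifts

edges = set()
for shift in range(0, box_size, shift_len):
    # Edges of one grid whose expanded-noise index is offset by `shift`.
    first_edge = (-shift) % box_size
    edges.update(range(first_edge, width + 1, box_size))

union = sorted(edges)
print(union[:6], "...", union[-1])
# Each individual frame still uses 100-pixel boxes; in 2D there are
# num_shifts**2 possible grid offsets.
print("distinct 2D grid offsets:", num_shifts ** 2)
```

Each single frame keeps the coarse 100-pixel boxes. Only the union over frames reaches the 20-pixel resolution.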
def shifted_noise(height, width, num_shifts, box_size, batch_size=1):
    """Create frames of a shifted 2D checkerboard noise stimulus."""
    assert box_size % num_shifts == 0, "box_size must be divisible by num_shifts"
    shift_len = box_size // num_shifts
    # Pad by one box in each direction so that shifted windows stay in bounds.
    h_with_pad = height + 1
    w_with_pad = width + 1
    noise = rng.normal(loc=0, scale=1, size=(batch_size, h_with_pad, w_with_pad))
    expand_tensor = np.ones(shape=(1, box_size, box_size))
    noise_expanded = np.kron(noise, expand_tensor)
    # One random (y, x) offset per frame: a multiple of shift_len in [0, box_size).
    rand_shift_y = rng.integers(low=0, high=num_shifts, size=(batch_size, 1, 1)) * shift_len
    rand_shift_x = rng.integers(low=0, high=num_shifts, size=(batch_size, 1, 1)) * shift_len
    ys, xs = np.mgrid[0:height*box_size, 0:width*box_size]
    ys = einops.rearrange(ys, "h w -> 1 h w")
    xs = einops.rearrange(xs, "h w -> 1 h w")
    assert ys.shape == xs.shape == (1, height*box_size, width*box_size), ys.shape
    # Broadcasting of the + operation: (1, h, w) -> (b, h, w)
    ys = ys + rand_shift_y
    xs = xs + rand_shift_x
    assert ys.shape == xs.shape == (batch_size, height*box_size, width*box_size), ys.shape
    batch_idxs = einops.rearrange(np.arange(batch_size), "b -> b 1 1")
    # Fancy indexing crops one (height*box_size, width*box_size) window per frame.
    frames = noise_expanded[batch_idxs, ys, xs]
    return frames
shifted_noise_fn = functools.partial(
shifted_noise, noise_height, noise_width, num_shifts, box_size)
shifted_noise_demo = shifted_noise_fn(batch_size=demo_len)
stimulus_clip(shifted_noise_demo).display_in_notebook(rd_kwargs={'logger':None})
3. Virtual cells
Here we create linear-nonlinear (LN) models, each defined by a 2D weight array: the cell’s receptive field. The weights are purely spatial, so time is not considered. If the dot product of the weights and the stimulus frame (the sum of their element-wise product) reaches a threshold value, this is counted as a “spike”. We will use these virtual cells to collect spike-triggered averages with the three stimulus types.
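To make the spiking rule concrete, here is a toy worked example on a 4 × 4 frame. The function `ln_spike` is a hypothetical stand-in that mirrors the rule described above, including the L1 kernel normalization that `sta_response` applies further down. It is not the notebook's own code.

```python
import numpy as np

def ln_spike(frames, kernel, threshold):
    """Hypothetical linear-threshold cell: spike when kernel . frame >= threshold."""
    kernel = kernel / np.abs(kernel).sum()          # L1-normalize the weights
    drive = np.einsum("bhw,hw->b", frames, kernel)  # one dot product per frame
    return drive >= threshold, drive

kernel = np.zeros((4, 4))
kernel[1:3, 1:3] = 1.0  # 2x2 excitatory center; each weight is 0.25 after normalizing

frames = np.stack([
    kernel.copy(),          # bright where the kernel is positive
    -kernel,                # dark where the kernel is positive
    np.ones((4, 4)) * 0.1,  # dim uniform frame
])
spikes, drive = ln_spike(frames, kernel, threshold=0.2)
print(drive, spikes)
```

The dim uniform frame only produces a drive of 0.1, so it falls below the 0.2 threshold.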
def _insert_square(arr, top_left_y, top_left_x, w, h, val=1):
    """Helper function to draw a filled h x w rectangle."""
    arr[top_left_y:top_left_y+h, top_left_x:top_left_x+w] = val
    return arr
def _load_receptive_field(tiff_path):
    img_rgba = skimage.io.imread(tiff_path).astype(float)
    img = img_rgba[:, :, 0:3]  # drop the alpha channel
    norm = np.sum(np.abs(img))
    return img, norm
def centered_rec_kernel(stim_shape, h, w, offset=(0,0)):
    """Create weights for a cell with a rectangular receptive field."""
    kernel = np.zeros(shape=stim_shape)
    top_left_y = (kernel.shape[0]-1) // 2 - h // 2 + offset[0]
    top_left_x = (kernel.shape[1]-1) // 2 - w // 2 + offset[1]
    kernel = _insert_square(kernel, top_left_y, top_left_x, w, h)
    return kernel
def center_surr_kernel(stim_shape, center, h, w, offset=(0,0)):
    """Create weights for a cell with a circular center-surround receptive field."""
    center = np.array(center) + np.array(offset)
    kernel = np.zeros(shape=stim_shape)
    # Surround: an ellipse with radii (h, w). Its area is 4 times that of the
    # center (radii h/2, w/2). The center overwrites part of it, leaving a ring
    # with 3 times the center's area, so a weight of -1/3 balances the center.
    rr, cc = skimage.draw.ellipse(center[0], center[1], h, w, shape=stim_shape)
    kernel[rr, cc] = -1/3
    # Center
    rr, cc = skimage.draw.ellipse(center[0], center[1], h//2, w//2, shape=stim_shape)
    kernel[rr, cc] = 1
    return kernel
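The -1/3 surround weight can be checked numerically. The sketch below swaps `skimage.draw.ellipse` for a plain NumPy disk mask, an assumption made only to keep the example dependency-free. It uses the same radii as the "small cc" kernel: 40 px for the surround and 20 px for the center. Because the disks are pixelated, the balance is approximate rather than exact.

```python
import numpy as np

def disk_mask(shape, center, radius):
    """Boolean mask of pixels within `radius` of `center` (row, col)."""
    yy, xx = np.mgrid[0:shape[0], 0:shape[1]]
    return (yy - center[0]) ** 2 + (xx - center[1]) ** 2 <= radius ** 2

shape, center, radius = (400, 400), (200, 200), 40

kernel = np.zeros(shape)
kernel[disk_mask(shape, center, radius)] = -1 / 3   # surround disk
kernel[disk_mask(shape, center, radius / 2)] = 1.0  # center overwrites the middle

center_total = kernel[kernel > 0].sum()
surround_total = kernel[kernel < 0].sum()
imbalance = kernel.sum() / center_total  # 0 would be a perfect balance
print(f"center {center_total:.0f}, surround {surround_total:.1f}, "
      f"relative imbalance {imbalance:.4f}")
```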
def painted_center_surr_kernel():
    """Load a painted center-surround receptive field."""
    kernel, norm = _load_receptive_field('./center_surround.tiff')
    # Just take one of the RGB channels; they are all the same.
    kernel = kernel[:, :, 0]
    return kernel
def painted_curve_kernel():
    """Load a painted c-shaped receptive field."""
    kernel, norm = _load_receptive_field('./curve.tiff')
    # Just take one of the RGB channels; they are all the same.
    kernel = kernel[:, :, 0]
    return kernel
def kernels():
    kernel_div = 5
    assert kernel_len % kernel_div == 0
    mini_kernel_len = kernel_len // kernel_div
    stim_shape = (world_height, world_width)
    res = []
    res.append(centered_rec_kernel(
        stim_shape,
        h=kernel_len,
        w=kernel_len,
        offset=kernel_offset))
    res.append(centered_rec_kernel(
        stim_shape,
        h=mini_kernel_len*2,
        w=kernel_len*6,
        offset=kernel_offset))
    res.append(center_surr_kernel(
        stim_shape,
        center=np.array([world_height, world_width])/2,
        h=world_height/10,
        w=world_width/10,
        offset=kernel_offset))
    res.append(painted_center_surr_kernel())
    res.append(painted_curve_kernel())
    return res
def kernel_labels():
    return ("square", "rectangle", "small cc", "large cc", "curve cc")
def noise_labels():
    return ("small", "big", "shifted")
def plot_kernels():
    dpi = mpl.rcParams['figure.dpi']
    h_in, w_in = 200 / dpi, 640*0.9 / dpi
    fig, axs = plt.subplots(1, 5, figsize=(w_in, h_in), sharey=True)
    vmin = -1
    vmax = 1
    norm = mpl.colors.Normalize(vmin, vmax)
    for i, k in enumerate(kernels()):
        axs[i].imshow(k, norm=norm, cmap=cmocean.cm.gray, origin="lower")
        axs[i].set_yticks([0, 400])
        axs[i].set_xticks([0, 400], [])
    axs[0].set_xticklabels([0, 400])
plot_kernels()

4. Spike-triggered average
Present each virtual cell with a sequence of frames. A cell “spikes” on a frame if the sum of the element-wise product of its kernel and the stimulus frame is at or above a specified threshold. The kernel is first normalized so that its absolute values sum to 1.
Below, we present 2048 frames to each virtual cell. This is repeated for 4 threshold values: 0, 0.1, 0.2 and 0.4.
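Why a spike-triggered average recovers the kernel at all: for white Gaussian noise, the frames that drive the cell above threshold are, on average, shifted along the direction of the kernel. The STA is therefore proportional to the kernel. The small simulation below illustrates this with toy sizes: an 8 × 8 frame, a 4 × 4 square kernel and 20,000 frames. It reuses the spiking rule described above but is not the notebook's code.

```python
import numpy as np

rng = np.random.default_rng(1)
size, n_frames = 8, 20000

kernel = np.zeros((size, size))
kernel[2:6, 2:6] = 1.0
kernel /= np.abs(kernel).sum()                    # L1 normalization

frames = rng.normal(size=(n_frames, size, size))  # independent N(0, 1) pixels
drive = np.einsum("bhw,hw->b", frames, kernel)
is_spike = drive >= 0.0
sta = frames[is_spike].mean(axis=0)               # spike-triggered average

similarity = np.corrcoef(sta.ravel(), kernel.ravel())[0, 1]
print(f"{is_spike.sum()} spikes, STA/kernel correlation {similarity:.3f}")
```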
Threshold interpretation
Smaller thresholds mean that cells spike more readily, requiring less of a match between stimulus and kernel. Lower thresholds will result in the spike-triggered average being an average of more frames, but each frame is likely to be just a partial reflection of the cell’s kernel. With high thresholds, there are fewer spikes, but each triggering stimulus frame will be more representative of the full kernel.
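The trade-off can be seen in a toy simulation. Raising the threshold can only remove spikes, and the frames that remain are the ones that matched the kernel more strongly, so their mean drive is higher. Sizes and thresholds below are illustrative. With this 4 × 4 kernel on white noise the drive has a standard deviation of 0.25, so the thresholds are on a similar scale to those used in this notebook, but the numbers will not carry over directly.

```python
import numpy as np

rng = np.random.default_rng(2)
size, n_frames = 8, 20000

kernel = np.zeros((size, size))
kernel[2:6, 2:6] = 1.0
kernel /= np.abs(kernel).sum()  # drive standard deviation is 0.25 here

frames = rng.normal(size=(n_frames, size, size))
drive = np.einsum("bhw,hw->b", frames, kernel)

counts, mean_drives = [], []
for threshold in (0.0, 0.1, 0.2, 0.4):
    is_spike = drive >= threshold
    counts.append(int(is_spike.sum()))
    # Mean drive of the spiking frames: how well they match the kernel.
    mean_drives.append(float(drive[is_spike].mean()))
    print(f"threshold {threshold}: {counts[-1]} spikes, mean drive {mean_drives[-1]:.3f}")
```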
def response(stim_frames, kernel, threshold=1):
    """Determine whether the virtual cell spikes on each frame.

    Args:
        stim_frames: stimulus frames, shape (batch, height, width).
        kernel: the cell's weights, shape (height, width).
        threshold: a spike is triggered when the drive is >= this value.

    Returns:
        (is_on, drive): boolean spike indicators and the linear drive per frame.
    """
    b, h, w = stim_frames.shape
    assert kernel.shape == (h, w)
    drive = np.sum(stim_frames * kernel, axis=(1, 2))
    is_on = drive >= threshold
    return is_on, drive
def sta_response(noise_seq, kernel, threshold, batch_size=1024, norm_kernel=True):
    """Compute the spike-triggered average (STA) incrementally.

    Returns one running-STA snapshot per batch of frames, plus the cumulative
    spike count at each snapshot.
    """
    if norm_kernel:
        knorm = np.abs(kernel).sum()
        kernel = kernel / knorm
    num_timesteps = len(noise_seq)
    n_batch = math.ceil(num_timesteps / batch_size)
    res = np.zeros((n_batch, *kernel.shape))
    spikes = np.zeros(n_batch)
    n_spikes = 0
    prev_res = np.zeros_like(kernel)
    for b in range(n_batch):
        bs = min(num_timesteps - b*batch_size, batch_size)
        stim_frames = noise_seq[b*batch_size:b*batch_size+bs]
        is_spike, _ = response(stim_frames, kernel, threshold)
        n_new_spikes = is_spike.sum()
        n_spikes += n_new_spikes
        spikes[b] = n_new_spikes
        if np.any(is_spike):
            # Running mean: weight the old STA and the batch STA by spike counts.
            w_new = n_new_spikes / n_spikes
            res[b] = prev_res * (1 - w_new) + stim_frames[is_spike].mean(axis=0) * w_new
        else:
            res[b] = prev_res
        prev_res = res[b]
    spikes = np.cumsum(spikes)
    return res, spikes
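The incremental update in `sta_response` is the standard running-mean recurrence. Suppose n spikes have been seen so far, including a new batch that contributes k spikes with mean m. The new average is then `old * (1 - k/n) + m * k/n`. The sketch below uses a hypothetical helper, `running_mean`, on random data with uneven and even empty batches. It checks that the recurrence matches averaging all spike-triggered frames at once.

```python
import numpy as np

def running_mean(batches):
    """Hypothetical helper: combine per-batch means weighted by batch counts."""
    mean, n_total = 0.0, 0
    for batch in batches:
        n_new = len(batch)
        if n_new == 0:  # no spikes in this batch: keep the previous mean
            continue
        n_total += n_new
        w_new = n_new / n_total
        mean = mean * (1 - w_new) + batch.mean(axis=0) * w_new
    return mean, n_total

rng = np.random.default_rng(3)
sizes = [5, 0, 17, 1, 9]
batches = [rng.normal(size=(n, 4, 4)) for n in sizes]

incremental, n = running_mean(batches)
direct = np.concatenate(batches).mean(axis=0)
print(n, np.abs(incremental - direct).max())
```

Keeping only the running mean and count avoids storing every spike-triggering frame, which matters for 400 × 400 stimuli.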
def run_all_stimuli(num_timesteps, threshold, batch_size):
    noise_fns = (("small", small_noise_fn), ("big", big_noise_fn), ("shifted", shifted_noise_fn))
    res = {}
    for noise_label, noise_fn in noise_fns:
        noise = noise_fn(batch_size=num_timesteps)
        for k, kernel_label in zip(kernels(), kernel_labels()):
            res[(noise_label, kernel_label)] = sta_response(noise, k, threshold, batch_size)
    return res
def plot(data):
    n_batch = len(next(iter(data.values()))[0])
    dpi = mpl.rcParams['figure.dpi']
    h_in, w_in = 800 / dpi, 640*0.9 / dpi
    cmap = cmocean.cm.gray
    vmin = -1
    vmax = 1
    norm = mpl.colors.Normalize(vmin, vmax)
    # Load the kernels once rather than once per frame.
    all_kernels = kernels()
    frames = []
    for b in range(n_batch):
        # A bare Figure is not registered with pyplot, so nothing needs closing.
        fig = mpl.figure.Figure(figsize=(w_in, h_in), dpi=dpi)
        n_kernels = len(all_kernels)
        axs = fig.subplots(n_kernels, 4)
        for i, (k, kernel_label) in enumerate(zip(all_kernels, kernel_labels())):
            axs[i][0].imshow(k, norm=norm, cmap=cmap, origin="lower")
            for j, noise_label in enumerate(noise_labels()):
                v, s = data[(noise_label, kernel_label)]
                axs[i][j+1].imshow(v[b], cmap=cmap, origin="lower", norm=norm)
                axs[i][j+1].text(0.27, 1.15, f"{int(s[b])} spikes", ha="left", va="top", color="black", transform=axs[i][j+1].transAxes, fontsize=6)
        for (i, j), ax in np.ndenumerate(axs):
            ax.set_xticks([0, 400], [])
            ax.set_yticks([0, 400], [])
        axs[-1][0].set_xticks([0, 400], [0, 400])
        axs[-1][0].set_yticks([0, 400], [0, 400])
        for j, label in enumerate(noise_labels()):
            axs[0][j+1].set_title(label, pad=10)
        frames.append(to_frame(fig, dpi))
    return frames
def to_frame(fig, dpi):
    """Convert a figure to an RGBA array."""
    canvas = FigureCanvasAgg(fig)
    canvas.draw()
    buf = canvas.buffer_rgba()
    res = np.asarray(buf)
    return res
data = run_all_stimuli(2048, threshold=0, batch_size=32)
ImageSequenceClip(plot(data), fps=5).display_in_notebook(rd_kwargs={'logger':None})
data = run_all_stimuli(2048, threshold=0.1, batch_size=32)
ImageSequenceClip(plot(data), fps=5).display_in_notebook(rd_kwargs={'logger':None})
data = run_all_stimuli(2048, threshold=0.2, batch_size=32)
ImageSequenceClip(plot(data), fps=5).display_in_notebook(rd_kwargs={'logger':None})
data = run_all_stimuli(2048, threshold=0.4, batch_size=32)
ImageSequenceClip(plot(data), fps=5).display_in_notebook(rd_kwargs={'logger':None})
5. Thoughts on shifted noise
I have seen shifted noise posited as an “equivalent but better” version of the corresponding small-box noise. The simulations above show that the two are not equivalent. At the finer grid resolution, neighbouring pixels that are independent under the small-box noise are correlated under the shifted noise, because on most frames they fall inside the same big box. The spike-triggered average is therefore distorted relative to the one obtained with small-box noise. Shifted noise is still useful in that it can require fewer frames to detect the presence of a cell.
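The loss of independence can be quantified. Along one axis, two fine pixels d small boxes apart fall in the same big box on a fraction (num_shifts - d)/num_shifts of frames, given uniform offsets. Otherwise they fall in independent boxes. Their correlation is therefore max(0, 1 - d/num_shifts), whereas for small-box noise it is 0 for any d ≥ 1. The sketch below checks this with a hypothetical 1D version of the shifted noise. It works in small-box units, so each big box spans `num_shifts` fine pixels.

```python
import numpy as np

def shifted_noise_1d(n_frames, n_big, num_shifts, rng):
    """Hypothetical 1D shifted noise, in units of the small box size."""
    values = rng.normal(size=(n_frames, n_big + 1))   # one extra box as padding
    expanded = np.repeat(values, num_shifts, axis=1)  # each big box -> num_shifts pixels
    offsets = rng.integers(0, num_shifts, size=(n_frames, 1))
    xs = np.arange(n_big * num_shifts)[None, :] + offsets
    return np.take_along_axis(expanded, xs, axis=1)   # (n_frames, n_big * num_shifts)

def corr_at_distance(frames, d):
    """Pearson correlation over all pixel pairs that are d fine pixels apart."""
    width = frames.shape[1]
    return np.corrcoef(frames[:, :width - d].ravel(), frames[:, d:].ravel())[0, 1]

rng = np.random.default_rng(4)
num_shifts, n_big, n_frames = 5, 4, 20000
shifted = shifted_noise_1d(n_frames, n_big, num_shifts, rng)
small = rng.normal(size=shifted.shape)  # independent small boxes for comparison

for d in range(num_shifts + 2):
    theory = max(0.0, 1 - d / num_shifts)
    print(f"d={d}: shifted {corr_at_distance(shifted, d):.3f} (theory {theory:.1f}), "
          f"small {corr_at_distance(small, d):.3f}")
```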
The smoother, Gaussian-like estimates obtained with shifted noise are in part due to the noise itself. The prior belief that cells have Gaussian-like receptive fields should therefore not be taken as empirical evidence in support of shifted noise, as the noise is, by its design, going to produce such receptive field estimates.
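One way to see why the noise itself smooths the estimate: take any stimulus with pixel covariance C and a linear drive z = k · s. The stimulus-response cross-correlation is then E[s z] = C k. With white noise, C is the identity, so this returns k. With shifted noise, C is the triangular correlation max(0, 1 - d/num_shifts), so k comes back blurred. The sketch below checks this in 1D with a sharp 3-pixel boxcar kernel. It uses a linear cross-correlation rather than the thresholded STA used above, a swap made because the linear case has this closed form. It also relies on a hypothetical 1D shifted-noise generator in small-box units.

```python
import numpy as np

def shifted_noise_1d(n_frames, n_big, num_shifts, rng):
    """Hypothetical 1D shifted noise, in units of the small box size."""
    values = rng.normal(size=(n_frames, n_big + 1))
    expanded = np.repeat(values, num_shifts, axis=1)
    offsets = rng.integers(0, num_shifts, size=(n_frames, 1))
    xs = np.arange(n_big * num_shifts)[None, :] + offsets
    return np.take_along_axis(expanded, xs, axis=1)

rng = np.random.default_rng(5)
num_shifts, n_big, n_frames = 5, 4, 60000
width = n_big * num_shifts

kernel = np.zeros(width)
kernel[8:11] = 1.0  # sharp 3-pixel receptive field

# Covariance of shifted noise: triangular in the pixel distance.
dist = np.abs(np.arange(width)[:, None] - np.arange(width)[None, :])
cov = np.clip(1 - dist / num_shifts, 0, None)
predicted = cov @ kernel  # E[s z] = C k

frames = shifted_noise_1d(n_frames, n_big, num_shifts, rng)
drive = frames @ kernel                            # linear response per frame
estimate = (frames * drive[:, None]).mean(axis=0)  # empirical E[s z]

print(np.round(predicted, 2))
print(np.round(estimate, 2))
```

The recovered profile extends num_shifts - 1 pixels beyond each edge of the true kernel, and its edges become ramps rather than steps.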