
Shifted noise for receptive field estimation

For receptive field mapping of stimulus-driven neuronal responses, Gaussian checkerboard noise is popular due to its statistical properties. There is a trade-off when choosing the box size of the grid. Small boxes increase the resolution of the estimated 2D receptive field. However, as the box size is reduced, it becomes less likely that the boxes covering a cell’s receptive field will collectively elicit a response, because the intensities of many independent boxes tend to cancel out. One proposed solution is to use a grid of large boxes and to shift this grid by random offsets, with the offsets being multiples of the desired, finer box size. This notebook is an empirical investigation comparing this shifted noise to standard checkerboard noise. It provides evidence that shifted noise cannot be considered a drop-in replacement with the same statistical properties as vanilla checkerboard noise.
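To put a number on this trade-off, the following sketch (an illustration added here, not code from the notebook) measures how strongly a uniform square kernel is driven by checkerboard noise with small and big boxes. It assumes a hypothetical 100 × 100 pixel kernel aligned with the grid, with weights normalized to sum to 1 (as `sta_response` does by default), so the cell’s drive is simply the mean intensity of the boxes it covers. `kernel_drive` is a made-up helper name.

```python
import numpy as np

def kernel_drive(box_size: int, kernel_size: int = 100, n_frames: int = 20_000,
                 seed: int = 0) -> np.ndarray:
    """Drive of a uniform square kernel (weights sum to 1) for each noise frame.

    Assumes the kernel is aligned with the grid and kernel_size is a multiple of
    box_size, so it covers (kernel_size // box_size) ** 2 whole boxes.
    """
    demo_rng = np.random.default_rng(seed)
    n_boxes = kernel_size // box_size
    # One N(0, 1) intensity per covered box. All pixels in a box share that value
    # and every box covers the same number of kernel pixels, so the kernel-weighted
    # pixel mean equals the plain mean over the covered boxes.
    boxes = demo_rng.normal(size=(n_frames, n_boxes, n_boxes))
    return boxes.mean(axis=(1, 2))

threshold = 0.4  # the largest threshold used later in the notebook
for box in (20, 100):
    drive = kernel_drive(box)
    print(f"{box:3d} px boxes: drive std {drive.std():.3f}, "
          f"fraction of frames with drive >= {threshold}: {np.mean(drive >= threshold):.3f}")
```

With 20-pixel boxes the kernel averages 25 independent boxes, so the standard deviation of the drive drops to about a fifth of its value with 100-pixel boxes, and far fewer frames cross the threshold.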

import functools
import math

import cmocean
import einops
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import scienceplots  # registers the "science" and "ieee" matplotlib styles
import skimage.draw
import skimage.io
from matplotlib.backends.backend_agg import FigureCanvasAgg
from moviepy import ImageSequenceClip
# Stimulus size in pixels.
world_height, world_width = 400, 400
# Big boxes: side length in pixels, and the grid size in boxes.
box_size = 100
assert world_height % box_size == world_width % box_size == 0
noise_height = world_height // box_size
noise_width = world_width // box_size
# Number of possible offsets per axis for the shifted noise; also sets the small box size.
num_shifts = 5
assert box_size % num_shifts == 0
small_box_size = box_size // num_shifts
small_noise_height = noise_height * num_shifts
small_noise_width = noise_width * num_shifts
# Virtual-cell kernel size and offset from the center, in pixels.
kernel_len = 25
kernel_offset = (12, 0)
# Number of frames in each demo clip.
demo_len = 50
rng = np.random.default_rng(123)
def init_matplotlib():
    plt.style.use(["science", "ieee"])
    plt.rcParams.update({
        "font.size": 8,
        "legend.fontsize": 8,
        "axes.titlesize": 9,
        "axes.grid": False,
        "figure.dpi":96*2,
        "xtick.top": False})
init_matplotlib()

1. Checkerboard noise (small, big and shifted)

The vanilla checkerboard noise assigns each square a random intensity, independent across frames and grid positions.

Below, we generate Gaussian checkerboard noise with two box sizes. After that, we show shifted noise whose boxes have the larger box size and whose offsets are multiples of the smaller box size.

def noise_frame(height: int, width: int, box_size : int = 1, batch_size=1):
    """Create frames of a checkerboard noise stimulus."""
    noise = rng.normal(loc=0, scale=1,  size=(batch_size, height, width))
    if box_size > 1:
        noise = np.kron(noise, np.ones(shape=(1, box_size, box_size)))
    return noise

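def stimulus_clip(stimulus: np.ndarray, zoom=1, fps=2, clip_sd=3):
    """Turn (batch, height, width) noise frames into a grayscale video clip."""
    expand_tensor = np.ones(shape=(1, zoom, zoom))
    stimulus = np.kron(stimulus, expand_tensor)
    # Map N(0, 1) intensities from [-clip_sd, clip_sd] onto [0, 255]. Scaling the raw
    # values by 255 leaves negative and >255 values that do not cast cleanly to uint8.
    stimulus = np.clip((stimulus + clip_sd) / (2 * clip_sd), 0, 1) * 255
    stimulus = einops.repeat(stimulus.astype(np.uint8), "b h w -> b h w 3")
    clip = ImageSequenceClip(list(stimulus), fps=fps)
    return clip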

1.1 Small boxes

small_noise_fn = functools.partial(noise_frame, small_noise_height, small_noise_width, box_size=small_box_size)
small_noise_demo = small_noise_fn(batch_size=demo_len)
stimulus_clip(small_noise_demo).display_in_notebook(rd_kwargs={'logger':None})

1.2 Big boxes

big_noise_fn = functools.partial(noise_frame, noise_height, noise_width, box_size=box_size)
big_noise_demo = big_noise_fn(batch_size=demo_len)
stimulus_clip(big_noise_demo).display_in_notebook(rd_kwargs={'logger':None})

2. Shifted noise

This checkerboard noise has boxes the same size as those in the “Big boxes” stimulus above. In each frame, the grid is shifted vertically and horizontally by independent random offsets, each a multiple of the small box size. If all possible grids are overlaid on top of each other, the resulting grid has boxes the same size as those in the “Small boxes” stimulus.
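The overlay claim can be checked along one axis with a small sketch. It assumes the same indexing as `shifted_noise`, where pixel `p` of a grid shifted by `shift` pixels belongs to box `(p + shift) // box_size`. The helper `box_edges` is hypothetical.

```python
import numpy as np

def box_edges(extent: int, box_size: int, shift: int) -> set:
    """Pixel positions in [0, extent) where a box starts along one axis.

    Pixel p of a grid shifted by `shift` pixels lies in box (p + shift) // box_size,
    the same indexing as in shifted_noise.
    """
    box_idx = (np.arange(extent) + shift) // box_size
    # A box starts at p = 0 and wherever the box index changes.
    starts = np.flatnonzero(np.diff(box_idx)) + 1
    return {0, *starts.tolist()}

world, big, n_shifts = 400, 100, 5
step = big // n_shifts                  # small box size: 20 pixels
overlay = set()
for k in range(n_shifts):
    overlay |= box_edges(world, big, k * step)
print(sorted(overlay))
```

Each individual grid still has 100-pixel boxes. Only the union of the edges over the five offsets gives every multiple of 20, which is the grid of the small-box stimulus.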

def shifted_noise(height, width, num_shifts, box_size, batch_size=1):
    """Create frames of a shifted 2D checkerboard noise stimulus.

    `height` and `width` count big boxes, so each frame has
    (height * box_size, width * box_size) pixels. Every frame's grid is offset
    vertically and horizontally by independent random multiples of
    box_size // num_shifts.
    """
    assert box_size % num_shifts == 0, "Box size must be divisible by num_shifts"
    shift_len = box_size // num_shifts
    # One extra row and column of boxes keeps the shifted indices in bounds.
    h_with_pad = height + 1
    w_with_pad = width + 1
    noise = rng.normal(loc=0, scale=1, size=(batch_size, h_with_pad, w_with_pad))
    expand_tensor = np.ones(shape=(1, box_size, box_size))
    noise_expanded = np.kron(noise, expand_tensor)
    # Per-frame offsets in pixels: 0, shift_len, ..., (num_shifts - 1) * shift_len.
    rand_shift_y = rng.integers(low=0, high=num_shifts, size=(batch_size, 1, 1)) * shift_len
    rand_shift_x = rng.integers(low=0, high=num_shifts, size=(batch_size, 1, 1)) * shift_len
    ys, xs = np.mgrid[0:height*box_size, 0:width*box_size]
    ys = einops.rearrange(ys, "h w -> 1 h w")
    xs = einops.rearrange(xs, "h w -> 1 h w")
    assert ys.shape == xs.shape == (1, height*box_size, width*box_size), ys.shape
    # Broadcasting of the + operation: (1, h, w) -> (b, h, w)
    ys = ys + rand_shift_y
    xs = xs + rand_shift_x
    assert ys.shape == xs.shape == (batch_size, height*box_size, width*box_size), ys.shape
    batch_idxs = einops.rearrange(np.arange(batch_size), "b -> b 1 1")
    # Sample each frame of the padded, expanded noise at its shifted pixel coordinates.
    shifted = noise_expanded[batch_idxs, ys, xs]
    return shifted
shifted_noise_fn = functools.partial(
    shifted_noise, noise_height, noise_width, num_shifts, box_size)
shifted_noise_demo = shifted_noise_fn(batch_size=demo_len) 
stimulus_clip(shifted_noise_demo).display_in_notebook(rd_kwargs={'logger':None})

3. Virtual cells

Here we create linear-nonlinear (LN) models of cells, each defined by a 2D weight array: the cell’s receptive field. The weights are only 2D, so time is not considered. If the dot product of the weights and a stimulus frame reaches a threshold value, this is considered a “spike”. We will use these virtual cells to collect spike-triggered averages with the 3 stimulus types.
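Written out, the model that `response` and `sta_response` implement below (with the default `norm_kernel=True`) is the following, where \(w\) is the kernel, \(s_t\) the stimulus frame at time \(t\), \(\theta\) the threshold and \(r_t\) the spike indicator. This notation is introduced here for illustration; the notebook itself works directly with arrays.

\[
z_t = \frac{\sum_{x,y} w(x,y)\, s_t(x,y)}{\sum_{x,y} \lvert w(x,y) \rvert},
\qquad
r_t = \begin{cases} 1 & \text{if } z_t \ge \theta, \\ 0 & \text{otherwise,} \end{cases}
\]

\[
\mathrm{STA} = \frac{1}{N} \sum_{t \,:\, r_t = 1} s_t, \qquad N = \sum_t r_t .
\]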

def _insert_square(arr, top_left_y, top_left_x, w, h, val=1):
    """Helper function to draw a filled rectangle.

    The end indices are inclusive, so the rectangle spans h + 1 rows and w + 1 columns.
    """
    arr[top_left_y:top_left_y+h+1, top_left_x:top_left_x+w+1] = val
    return arr

def _load_receptive_field(tiff_path):
    img_rgba = skimage.io.imread(tiff_path).astype(float)
    img = img_rgba[:,:,0:3]
    norm = np.sum(np.abs(img))
    return img, norm

def centered_rec_kernel(stim_shape, h, w, offset=(0,0)):
    """Create weights for a cell with a rectangular receptive field."""
    kernel = np.zeros(shape=stim_shape)
    top_left_y = (kernel.shape[0]-1) // 2 - h // 2 + offset[0]
    top_left_x = (kernel.shape[1]-1) // 2 - w // 2 + offset[1]
    kernel = _insert_square(kernel, top_left_y, top_left_x, w, h)
    return kernel

def center_surr_kernel(stim_shape, center, h, w, offset=(0,0)):
    """Create weights for a cell with a circular center-surround receptive field."""
    center = np.array(center) + np.array(offset)
    kernel = np.zeros(shape=stim_shape)
    # Surround: the outer ellipse has twice the radii, so 4 times the area, of the center.
    # The center is drawn over it below, leaving a surround annulus with 3 times the
    # center's area, so a weight of -1/3 makes the kernel sum to roughly 0.
    rr, cc = skimage.draw.ellipse(center[0], center[1], h, w)
    kernel[rr, cc] = -1/3
    # Center
    rr, cc = skimage.draw.ellipse(center[0], center[1], h//2, w//2)
    kernel[rr, cc] = 1
    return kernel
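The balancing argument in the comment above can be checked numerically. This sketch replaces `skimage.draw.ellipse` with a plain NumPy disk mask (an assumption; skimage’s rasterization may differ by a few pixels) and uses the radii from `kernels()`: an outer radius of 40 pixels (`world_height / 10`) and a center radius of half that. `disk_mask` is a hypothetical helper.

```python
import numpy as np

def disk_mask(shape, center, radius):
    """Boolean mask of pixels strictly inside a circle (stand-in for skimage.draw.ellipse)."""
    yy, xx = np.mgrid[0:shape[0], 0:shape[1]]
    return (yy - center[0]) ** 2 + (xx - center[1]) ** 2 < radius ** 2

shape, center, radius = (400, 400), (212, 200), 40
kernel = np.zeros(shape)
kernel[disk_mask(shape, center, radius)] = -1 / 3       # surround: whole outer disk first
kernel[disk_mask(shape, center, radius // 2)] = 1       # center: overwrites the inner disk
n_center = np.count_nonzero(kernel > 0)
n_surround = np.count_nonzero(kernel < 0)
print(f"surround/center area ratio: {n_surround / n_center:.3f}")
print(f"net weight relative to total |weight|: {kernel.sum() / np.abs(kernel).sum():.4f}")
```

Because pixels are discrete, the ratio is only approximately 3, so the kernel is balanced to within a small fraction of its total absolute weight rather than exactly.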

def painted_center_surr_kernel():
    """Load a painted center-surround receptive field."""
    kernel, norm = _load_receptive_field('./center_surround.tiff')
    # Just take one of the RGB channels, they are all the same.
    kernel = kernel[:,:,0]
    return kernel 

def painted_curve_kernel():
    """Load a painted c-shaped receptive field."""
    kernel, norm = _load_receptive_field('./curve.tiff')
    # Just take one of the RGB channels, they are all the same.
    kernel = kernel[:,:,0]
    return kernel 
def kernels():
    kernel_div = 5
    assert kernel_len % kernel_div == 0
    mini_kernel_len = kernel_len // kernel_div
    stim_shape = (world_height, world_width)
    res = []
    res.append(centered_rec_kernel(
        stim_shape,
        h=kernel_len, 
        w=kernel_len, 
        offset=kernel_offset))
    res.append(centered_rec_kernel(
        stim_shape,
        h=mini_kernel_len*2, 
        w=kernel_len*6, 
        offset=kernel_offset))
    res.append(center_surr_kernel(
        stim_shape,
        center=np.array([world_height, world_width])/2, 
        h=world_height/10,
        w=world_height/10,
        offset=kernel_offset))
    res.append(painted_center_surr_kernel())
    res.append(painted_curve_kernel())
    return res
    
def kernel_labels():
    return ("square", "rectangle", "small cc", "large cc", "curve cc")

def noise_labels():
    return ("small", "big", "shifted")
    
def plot_kernels():
    dpi = mpl.rcParams['figure.dpi']
    h_in, w_in = 200 / dpi, 640*0.9 / dpi
    fig, axs = plt.subplots(1, 5, figsize=(w_in, h_in), sharey=True)
    vmin = -1
    vmax = 1
    norm = mpl.colors.Normalize(vmin,vmax)
    for i,k in enumerate(kernels()):
        axs[i].imshow(k, norm=norm, cmap=cmocean.cm.gray, origin="lower")
        axs[i].set_yticks([0,400])
        axs[i].set_xticks([0,400], [])
    axs[0].set_xticklabels([0,400])
plot_kernels()


4. Spike-triggered average

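Each virtual cell is presented with a sequence of frames. A cell “spikes” when the dot product of its kernel and the stimulus frame (the sum of their element-wise product) reaches a specified threshold. By default, the kernel is first normalized so that the absolute values of its weights sum to 1.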

Below, we present 2048 frames to each virtual cell. This is repeated for 4 different threshold values: 0, 0.1, 0.2 and 0.4.
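For a linear-threshold cell driven by white Gaussian noise, the spike-triggered average is proportional to the cell’s kernel, which is why the STA can serve as a receptive field estimate. The following 1D sketch (toy sizes and a hypothetical bump-shaped kernel, not the notebook’s 2D kernels) applies the same steps as `response` and `sta_response`: normalize the kernel, take the dot product with each frame, threshold, and average the frames that produced a spike.

```python
import numpy as np

demo_rng = np.random.default_rng(0)
n_pixels, n_frames, threshold = 50, 20_000, 0.2

# Hypothetical 1D kernel: a bump over pixels 20-29, normalized so sum(|w|) == 1.
w = np.zeros(n_pixels)
w[20:30] = 1.0
w /= np.abs(w).sum()

stim = demo_rng.normal(size=(n_frames, n_pixels))  # white Gaussian noise, one row per frame
drive = stim @ w                                   # linear stage
spikes = drive >= threshold                        # step nonlinearity
sta = stim[spikes].mean(axis=0)                    # spike-triggered average

print(f"{spikes.sum()} spikes, correlation(STA, kernel) = {np.corrcoef(sta, w)[0, 1]:.3f}")
```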

Threshold interpretation

Smaller thresholds mean that cells spike more readily, requiring less of a match between stimulus and kernel. With lower thresholds, the spike-triggered average is an average of more frames, but each frame is likely to be only a partial reflection of the cell’s kernel. With higher thresholds, there are fewer spikes, but each triggering stimulus frame is more representative of the full kernel. Because the kernels are normalized so that their absolute weights sum to 1, the threshold is in the same units as the stimulus intensity, which has unit standard deviation.
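The trade-off can be illustrated with simulated drive values alone. This sketch assumes, hypothetically, that a cell’s normalized drive is Gaussian with a standard deviation of 0.2 (the actual value depends on the kernel and the noise). For each threshold used in the notebook, it counts how many frames trigger a spike and how strongly those frames drive the cell on average.

```python
import numpy as np

demo_rng = np.random.default_rng(1)
# Hypothetical drive values: Gaussian with standard deviation 0.2.
drive = demo_rng.normal(scale=0.2, size=100_000)

thresholds = (0, 0.1, 0.2, 0.4)  # the thresholds used in the notebook
counts, mean_drives = [], []
for threshold in thresholds:
    spiking = drive >= threshold
    counts.append(int(spiking.sum()))
    # Mean drive of the spike-triggering frames: how well they match the kernel.
    mean_drives.append(float(drive[spiking].mean()))
    print(f"threshold {threshold}: {counts[-1]} spikes, mean drive {mean_drives[-1]:.3f}")
```

Raising the threshold lowers the spike count and raises the mean drive of the frames that remain, which is the trade-off described above.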

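def response(stim_frames, kernel, threshold=1):
    """Determine which frames make the virtual cell spike.

    Args:
        stim_frames: stimulus frames, shape (batch, height, width).
        kernel: the cell's weights, shape (height, width).
        threshold: the drive at or above which a spike is triggered.

    Returns:
        A boolean spike indicator per frame and the linear drive per frame.
    """
    b, h, w = stim_frames.shape
    assert kernel.shape == (h, w)
    # Linear stage: dot product of each frame with the kernel.
    drive = np.sum(stim_frames * kernel, axis=(1, 2))
    # Nonlinear stage: a step function.
    is_on = drive >= threshold
    return is_on, drive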
    
def sta_response(noise_seq, kernel, threshold, batch_size=1024, norm_kernel=True):
    """Compute the spike-triggered average (STA) of a virtual cell, batch by batch.

    Returns the running STA after each batch, shape (n_batch, height, width),
    and the cumulative number of spikes after each batch.
    """
    if norm_kernel:
        # Scale the kernel so that its absolute weights sum to 1.
        knorm = np.abs(kernel).sum()
        kernel = kernel / knorm
    num_timesteps = len(noise_seq)
    n_batch = math.ceil(num_timesteps / batch_size)
    res = np.zeros((n_batch, *kernel.shape))
    spikes = np.zeros(n_batch)
    n_spikes = 0
    prev_res = np.zeros_like(kernel)
    for b in range(n_batch):
        bs = min(num_timesteps - b*batch_size, batch_size)
        stim_frames = noise_seq[b*batch_size:b*batch_size+bs]
        is_spike, _ = response(stim_frames, kernel, threshold)
        n_new_spikes = is_spike.sum()
        n_spikes += n_new_spikes
        spikes[b] = n_new_spikes
        if np.any(is_spike):
            # Running mean: weight the previous STA and the new batch's mean by
            # their shares of all spikes so far.
            res[b] = prev_res * (1 - n_new_spikes/n_spikes) + stim_frames[is_spike].mean(axis=0) * n_new_spikes/n_spikes
        else:
            res[b] = prev_res
        prev_res = res[b]
    spikes = np.cumsum(spikes)
    return res, spikes
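The update inside the loop of `sta_response` is a running (incremental) mean. The previous STA is weighted by its share of all spikes so far, and the new batch mean by the share of new spikes. The sketch below checks, on hypothetical random data, that this batch-by-batch update gives the same result as averaging all spike-triggering frames at once.

```python
import numpy as np

demo_rng = np.random.default_rng(2)
frames = demo_rng.normal(size=(1000, 8))   # hypothetical stimulus frames, flattened
is_spike = demo_rng.random(1000) < 0.1     # hypothetical spike indicator per frame

batch_size = 32
running_sta, n_spikes = np.zeros(8), 0
for start in range(0, len(frames), batch_size):
    batch = frames[start:start + batch_size]
    spiking = batch[is_spike[start:start + batch_size]]
    n_new = len(spiking)
    if n_new == 0:
        continue                           # no spikes: the STA is unchanged
    n_spikes += n_new
    # Same update as in sta_response.
    running_sta = running_sta * (1 - n_new / n_spikes) + spiking.mean(axis=0) * n_new / n_spikes

print(np.allclose(running_sta, frames[is_spike].mean(axis=0)))
```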
def run_all_stimuli(num_timesteps, threshold, batch_size):
    noise_fns = (("small", small_noise_fn), ("big", big_noise_fn), ("shifted", shifted_noise_fn))
    res = {}
    for noise_label, noise_fn in noise_fns:
        noise = noise_fn(batch_size=num_timesteps)
        for k, kernel_label in zip(kernels(), kernel_labels()):
            res[(noise_label, kernel_label)] = sta_response(noise, k, threshold, batch_size)
    return res
def plot(data):
    n_batch = len(next(iter(data.values()))[0])
    dpi = mpl.rcParams['figure.dpi']
    h_in, w_in = 800 / dpi, 640*0.9 / dpi
    cmap = cmocean.cm.gray
    vmin = -1
    vmax = 1
    norm = mpl.colors.Normalize(vmin,vmax)
    frames = []
    for b in range(n_batch):
        fig = mpl.figure.Figure(figsize=(w_in, h_in), dpi=dpi)
        n_kernels = 5
        axs = fig.subplots(n_kernels, 4)
        for i, (k, kernel_label) in enumerate(zip(kernels(), kernel_labels())):
            axs[i][0].imshow(k, norm=norm, cmap=cmocean.cm.gray, origin="lower")
            for j, noise_label in enumerate(noise_labels()):
                v, s = data[(noise_label, kernel_label)]
                axs[i][j+1].imshow(v[b], cmap=cmap, origin="lower", norm=norm)
                axs[i][j+1].text(0.27, 1.15, f"{int(s[b])} spikes", ha="left", va="top", color="black", transform=axs[i][j+1].transAxes, fontsize=6)
        for (i,j), ax in np.ndenumerate(axs):
            ax.set_xticks([0, 400], [])
            ax.set_yticks([0, 400], [])
        axs[-1][0].set_xticks([0, 400], [0, 400])
        axs[-1][0].set_yticks([0, 400], [0, 400])
        for j, label in enumerate(noise_labels()):
            axs[0][j+1].set_title(label, pad=10)
        frames.append(to_frame(fig, dpi))
        plt.close()
    return frames

def to_frame(fig, dpi):
    """Convert a figure to an array."""
    canvas = FigureCanvasAgg(fig)
    canvas.draw()
    buf = canvas.buffer_rgba()
    res = np.asarray(buf)
    return res
data = run_all_stimuli(2048, threshold=0, batch_size=32)
ImageSequenceClip(plot(data), fps=5).display_in_notebook(rd_kwargs={'logger':None})
data = run_all_stimuli(2048, threshold=0.1, batch_size=32)
ImageSequenceClip(plot(data), fps=5).display_in_notebook(rd_kwargs={'logger':None})
data = run_all_stimuli(2048, threshold=0.2, batch_size=32)
ImageSequenceClip(plot(data), fps=5).display_in_notebook(rd_kwargs={'logger':None})
data = run_all_stimuli(2048, threshold=0.4, batch_size=32)
ImageSequenceClip(plot(data), fps=5).display_in_notebook(rd_kwargs={'logger':None})

5. Thoughts on shifted noise

I have seen shifted noise presented as an “equivalent but better” version of the corresponding small-box noise. The simulations above show that they are not equivalent. At the finer grid resolution, neighbouring pixels that are independent in the small-box noise are correlated in the shifted noise, because they often fall within the same large box. These correlations spread each pixel’s contribution onto its neighbours, so the spike-triggered average is distorted (smoothed) compared with the one obtained with small-box noise. Shifted noise is still useful: because its large boxes drive cells more strongly, it can require fewer frames to detect the presence of a cell.
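The correlation between neighbouring fine-resolution pixels can be computed directly. This 1D sketch works in units of small boxes, with 4 big boxes of 5 small boxes each (matching `noise_height` and `num_shifts` above). Like `shifted_noise` does for each axis, it draws one uniform random shift per frame. For shifted noise, two small boxes \(d\) apart fall in the same big box with probability \(\max(0, 1 - d/5)\), so their correlation should be close to that value. For small-box noise, the correlation should be close to 0. With independent vertical and horizontal shifts, as in `shifted_noise`, the 2D correlation is the product of the two 1D factors.

```python
import numpy as np

demo_rng = np.random.default_rng(3)
n_shifts, n_big, n_frames = 5, 4, 50_000
n_fine = n_big * n_shifts      # number of small boxes along the axis

# Small-box noise: every small box independent.
small = demo_rng.normal(size=(n_frames, n_fine))

# Shifted noise: big-box values (plus one padding box) and one random shift per frame;
# small box i takes the value of big box (i + shift) // n_shifts.
big_values = demo_rng.normal(size=(n_frames, n_big + 1))
shift = demo_rng.integers(0, n_shifts, size=(n_frames, 1))
box_idx = (np.arange(n_fine) + shift) // n_shifts
shifted = np.take_along_axis(big_values, box_idx, axis=1)

def neighbour_corr(x, d):
    """Correlation between small boxes d apart, pooled over positions and frames."""
    return np.corrcoef(x[:, :-d].ravel(), x[:, d:].ravel())[0, 1]

for d in range(1, n_shifts + 1):
    print(f"d={d}: small {neighbour_corr(small, d):+.3f}, "
          f"shifted {neighbour_corr(shifted, d):+.3f}, "
          f"expected for shifted {max(0.0, 1 - d / n_shifts):.1f}")
```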

The smoother, Gaussian-like results from the shifted noise are in part due to the noise itself. The prior belief that cells have Gaussian-like receptive fields should not be used as empirical evidence to support the use of shifted noise, as the noise is, by its design, going to produce such receptive field estimates.
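Even a cell whose receptive field is a single small box illustrates this. The 1D sketch below uses toy sizes chosen for illustration: 4 big boxes of 5 small boxes each, and one random shift per frame. The hypothetical cell spikes whenever its one pixel reaches a threshold of 1. With small-box noise, the STA is confined to that pixel. With shifted noise, it spreads into a triangular profile that falls by 1/5 per small box, a shape set entirely by the noise. In 2D, with independent shifts on each axis, the profile becomes the product of two such triangles.

```python
import numpy as np

demo_rng = np.random.default_rng(4)
n_shifts, n_big, n_frames, threshold = 5, 4, 50_000, 1.0
n_fine = n_big * n_shifts      # number of small boxes along the axis
center = 10                    # hypothetical cell: its kernel is this single small box

def shifted_rows(n):
    """1D shifted noise in units of small boxes, one random shift per row."""
    values = demo_rng.normal(size=(n, n_big + 1))            # includes one padding big box
    shift = demo_rng.integers(0, n_shifts, size=(n, 1))
    box_idx = (np.arange(n_fine) + shift) // n_shifts
    return np.take_along_axis(values, box_idx, axis=1)

def sta(stim):
    """STA of the single-pixel cell: average the frames where that pixel >= threshold."""
    return stim[stim[:, center] >= threshold].mean(axis=0)

sta_small = sta(demo_rng.normal(size=(n_frames, n_fine)))
sta_shifted = sta(shifted_rows(n_frames))
# Normalize by the value at the cell's own pixel to compare shapes.
print(np.round(sta_small / sta_small[center], 2))
print(np.round(sta_shifted / sta_shifted[center], 2))
```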