
Frame registration and cell detection for 2-photon recordings

I worked on a project that involved 2-photon recording of the tectum of a zebrafish larva. For a single larva, we made 14 recordings, each 15 minutes long and followed by a gap of roughly 2 minutes. The larva moved around, especially at the beginning, so before extracting cell traces we needed a way to align frames and detect cells within and across recordings.

Concatenation of 14 recordings, each 15 minutes long with gaps of ~2 minutes. The 2-photon recording ran at 2 Hz. The video was made by taking the mean of batches of 16 frames (8 s) and concatenating these means. There is drift during each recording (especially the first two). There seems to be more movement between recordings, which might be due to how the 2-photon microscope operates, or perhaps the larva moves more when the laser is turned on or off.
Per-recording stack means for all 14 recordings (rec 0 to rec 13). The activity of the tissue slowly decreases as the recordings progress.
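For reference, the batch averaging used to make the video can be written as a single einops reduce. A minimal sketch, assuming the recording is a (frames, height, width) array:

import einops


def downsample_for_video(rec, batch=16):
    """Average consecutive batches of 16 frames (8 s at 2 Hz)."""
    n_full = (len(rec) // batch) * batch  # drop any incomplete final batch
    return einops.reduce(rec[:n_full], "(t b) h w -> t h w", "mean", b=batch)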

I had a go at using Caiman, Suite2P and jnormcorre for frame registration (alignment). All three of them were a bit frustrating to use, and I struggled to get them to align across multiple recordings. In the end, I found much more success using keypoint detection and matching.

Alignment with keypoints

We can use feature detection and matching to create transforms between recordings, or at a finer scale, between chunks of frames within a recording.

The function below finds matching features in two images and computes a transformation between the two keypoint sets. By passing in the means of two chunks of frames, we can obtain a transformation that aligns one chunk to the other. It uses SIFT, but you could use any feature detector and descriptor extractor.

from pathlib import Path
import logging
import numpy as np
import einops
import skimage
from functools import partial
import multiprocessing as mp


_logger = logging.getLogger(__name__)


def detect_and_shift(src, target, filter_fn=None):
    """Calculate a transform to convert the src image to look like target.

    Steps:
        1. Calculate SIFT image keypoints in both images.
        2. Retain only the keypoints present in both images.
        3. Retain only the keypoints within the area we are most concerned
           about. The point of this step is that we don't care so much about
           the other areas and are fine if they shift around. If we included
           keypoints from all areas, then a more modest transform would be
           produced that goes for a consensus across the whole image. Instead,
           we are happy for less interesting parts of the image to shift so
           that we keep interesting areas as steady as possible.
        4. Use off-the-shelf RANSAC optimizer to choose a matrix transform that
           moves target points to src points. The inverse of this transform will
           be used by skimage.transform.warp to transform the src image
           towards the target image.
    """
    # 1. Extract keypoints
    # c_dog has been reduced quite a bit from the default. Reducing c_dog
    # allows lower contrast features to be included. Some of the most salient
    # areas of the images are quite low contrast, so I reduced this value.
    extractor = skimage.feature.SIFT(
        n_octaves=2,
        n_scales=5,
        sigma_min=1.0,
        sigma_in=0.1,
        c_edge=5,
        c_dog=0.005,
        upsampling=1,
    )
    extractor.detect_and_extract(src)
    kp_src, desc_src = extractor.keypoints, extractor.descriptors
    extractor.detect_and_extract(target)
    kp_target, desc_target = extractor.keypoints, extractor.descriptors

    # 2. Match and filter keypoints.
    # 2.1. Match descriptors
    matches = skimage.feature.match_descriptors(
        desc_src, desc_target, max_ratio=0.6, cross_check=True
    )
    # Check if there are enough matches
    if (
        matches.shape[0] < 2
    ):  # At least 2 points needed for translation + scaling
        raise ValueError("Not enough matches to estimate transformation.")

    # 2.2 Filter
    if filter_fn is not None:
        matches = np.array(
            [
                m
                for m in matches
                if filter_fn(kp_src[m[0]]) and filter_fn(kp_target[m[1]])
            ]
        )

    # Extract matched points
    points_src = kp_src[matches[:, 0]]
    points_target = kp_target[matches[:, 1]]

    # 3. Transform
    # 3.1 Robust affine transformation estimation using RANSAC
    model, inliers = skimage.measure.ransac(
        # https://stackoverflow.com/a/62332179/754300
        # https://github.com/scikit-image/scikit-image/issues/1749
        (np.flip(points_target, axis=-1), np.flip(points_src, axis=-1)),
        skimage.transform.AffineTransform,
        min_samples=8,
        residual_threshold=1,
        max_trials=1000,
        rng=np.random.default_rng(123),
    )

    if inliers is None or not inliers.any():
        raise ValueError("RANSAC failed to find a valid transformation.")

    # Alternative:
    # tform = skimage.transform.AffineTransform()
    # tform.estimate(points_target, points_src)

    # 3.2 Apply the transformation
    aligned_img = skimage.transform.warp(
        src.astype(np.float32),
        model,
        output_shape=target.shape,
        preserve_range=False,
    )

    return points_src, points_target, aligned_img, model


def clip_and_norm(img):
    """Clip and normalize. The range is determined by the 2-photon setup."""
    r = (54600, 55000)
    img = np.clip(img, a_min=r[0], a_max=r[1]) - r[0]
    img = img / (r[1] - r[0])
    return img


def mean_clip_norm(x):
    """Mean clip and norm (defined here for multiprocessing)"""
    return clip_and_norm(x.mean(axis=0))


def rec_registration(rec, ref_img=None, n_batch=1, filter_fn=None):
    """Run registration on chunks of a recording, targeting ref_img.

    The recording is split into n_batch chunks, and each chunk is
    registered against the reference image. If ref_img is not provided,
    the mean of the middle chunk is used as the reference image.

    Larger chunks mean more stable registration, as the mean of larger chunks
    is less affected by the activity of the cells. But larger chunks mean
    fewer chunks, which means fewer transformations cover the recording.

    For frames within a chunk, the transformation used is a linear interpolation
    between neighbouring transformations.
    """
    B = len(rec) // n_batch
    # l_B is our designated "mid" point.
    l_B = B // 2
    u_B = (B + 1) // 2
    print(f"{len(rec)=}, {B=}, {l_B=}, {u_B=}")
    # Get mids by edges+l_B (only need lhs, so throw away last)
    mids = np.arange(0, len(rec) + 1, B)[:-1] + l_B
    assert len(mids) == n_batch, (len(mids), n_batch, mids)
    chunks = [rec[m - l_B : m + u_B] for m in mids]
    if ref_img is None:
        """If no reference image is given, use the middle chunk."""
        ref_img = clip_and_norm(chunks[len(chunks) // 2].mean(axis=0))

    def interp(T0, T1, p0, p1, query_idx):
        """Linear interpolation between two transforms."""
        assert p1 > p0
        dT = T1 - T0
        rel = (query_idx - p0) / (p1 - p0)
        res = T0 + rel * dT
        return res

    with mp.Pool(min(len(chunks), mp.cpu_count())) as pool:
        chunk_aves = []
        pool.map_async(
            mean_clip_norm, chunks, callback=chunk_aves.extend
        ).wait()
        img_models = []

        worker_fn = partial(
            detect_and_shift, target=ref_img, filter_fn=filter_fn
        )
        pool.map_async(
            worker_fn,
            chunk_aves,
            callback=img_models.extend,
            error_callback=lambda e: _logger.error(str(e)),
        ).wait()
        if len(img_models) != len(chunk_aves):
            assert len(img_models) < len(chunk_aves)
            raise ValueError(
                "Not enough models returned. Most likely one of the chunks "
                "didn't have enough shared features with the reference image. "
                f"{len(img_models)=}, {len(chunk_aves)=}"
            )
        Ts = [x[3].params for x in img_models]
        ts = []
        warp_ts = []
        for f in range(len(rec)):
            # Don't go past the 2nd last chunk.
            t0_idx = min(n_batch - 1, f // B)
            t1_idx = min(n_batch - 1, t0_idx + 1)
            idx0 = t0_idx * B + l_B
            idx1 = t1_idx * B + l_B
            if idx0 == idx1:
                t = Ts[t0_idx]
            else:
                t = interp(Ts[t0_idx], Ts[t1_idx], idx0, idx1, f)
            warp_ts.append(t)
            # Convert the transform from skimage's (x, y) convention back to
            # (row, col) order so it can later be applied to (y, x) coordinates.
            yx_t = t.copy()
            tmp = yx_t.copy()
            yx_t[0, 0] = tmp[1, 1]
            yx_t[0, 1] = tmp[1, 0]
            yx_t[0, 2] = tmp[1, 2]
            yx_t[1, 0] = tmp[0, 1]
            yx_t[1, 1] = tmp[0, 0]
            yx_t[1, 2] = tmp[0, 2]
            ts.append(yx_t)
        frames = pool.starmap(
            partial(skimage.transform.warp, preserve_range=True),
            [(rec[i].astype(np.float32), warp_ts[i]) for i in range(len(rec))],
        )
        frames = np.array(frames)
    assert frames.shape[1] == frames.shape[2] == rec.shape[1]
    ts = np.array(ts)
    return (frames, ts)
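
A rough usage sketch follows; the recording arrays, the rectangle used by the keypoint filter, and the n_batch value are placeholders rather than the exact values from our pipeline.

def keep_tectum(kp, r0=100, r1=400, c0=100, c1=400):
    """Hypothetical keypoint filter: keep (row, col) keypoints inside a rectangle."""
    return (r0 <= kp[0] < r1) and (c0 <= kp[1] < c1)


if __name__ == "__main__":
    # rec5 (the reference recording) and rec2 are (frames, height, width)
    # arrays loaded elsewhere.
    ref_img = clip_and_norm(rec5.mean(axis=0))
    aligned, ts = rec_registration(
        rec2, ref_img=ref_img, n_batch=8, filter_fn=keep_tectum
    )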

The figure below shows SIFT feature detection and matching between the average of two whole recordings.

Keypoints shared by the mean of recording 5 and the mean of recording 8.
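A figure like the one above can be produced with scikit-image's match-plotting helper. A minimal sketch, assuming mean_5 and mean_8 are the clipped and normalized means of the two recordings (recent scikit-image versions rename plot_matches to plot_matched_features):

import matplotlib.pyplot as plt
import numpy as np
import skimage.feature

# detect_and_shift returns already-paired points, so the match indices are 1:1.
points_src, points_target, _, _ = detect_and_shift(mean_8, mean_5)
pairs = np.column_stack([np.arange(len(points_src))] * 2)
fig, ax = plt.subplots(figsize=(10, 5))
skimage.feature.plot_matches(
    ax, mean_8, mean_5, points_src, points_target, pairs, only_matches=True
)
ax.set_axis_off()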
To align all recordings, I chose the 6th recording (rec 5) as a reference. The mean of this whole recording is the reference image that the other recordings are aligned to. The 6th recording was chosen as it is roughly in the middle of the recording sequence and had good overlap of features with the other recordings. For the other recordings, I ran the `detect_and_shift` function on chunks of frames, obtaining a transform that aligns each chunk mean with the reference image. These transforms are then interpolated to get a transform for each frame in a recording.
Recordings 2-8 aligned to recording 5. With the SIFT and RANSAC settings above, there were not enough matching keypoints to align the other recordings. We only needed 6 recordings anyway, so we didn't bother tweaking the settings further.
Aligned stack means for recordings 2-8. There is more movement in the lower part of the image (cerebellum) where we specifically did not look for features. This shows how you can keep the area of the image you care about steady at the expense of other areas.

Cell detection as blob detection

Continuing the DIY approach, the next step is to use blob detection to find cell bodies. scikit-image has a few blob detection algorithms; following the example from the docs, we apply three of them to the mean stack of recording 5. While we are at it, we apply a hand-painted mask to select the area we are interested in.
import math

import matplotlib.pyplot as plt


def plot_blob_detection(in_img):
    """Mask an image and run three blob detectors on it.

    brush_mask (a hand-painted mask) is defined elsewhere.
    """
    masked = brush_mask(in_img)
    fig1, ax = plt.subplots(1, 2)
    ax[0].imshow(in_img, cmap="gray")
    ax[1].imshow(masked, cmap="gray")
    ax[0].set_title("Input")
    ax[1].set_title("Masked")
    ax[0].set_axis_off()
    ax[1].set_axis_off()
    fig1.tight_layout()

    blobs_log = skimage.feature.blob_log(
        masked,
        min_sigma=1.2,
        max_sigma=2,
        threshold=0.035,
        overlap=0.1,
        threshold_rel=0.10,
        exclude_border=12,
    )
    blobs_dog = skimage.feature.blob_dog(
        masked,
        min_sigma=1.4,
        max_sigma=2,
        threshold=0.04,
        overlap=0.2,
        threshold_rel=0.05,
        exclude_border=12,
    )
    blobs_doh = skimage.feature.blob_doh(
        masked,
        min_sigma=0.0,
        max_sigma=3,
        threshold=0.002,
        overlap=0.2,
        threshold_rel=0.003,
    )
    # Compute radii in the 3rd column.
    blobs_log[:, 2] = blobs_log[:, 2] * math.sqrt(2)
    blobs_dog[:, 2] = blobs_dog[:, 2] * math.sqrt(2)
    blobs_list = [blobs_log, blobs_dog, blobs_doh]
    colors = ["yellow", "lime", "red"]
    titles = [
        "Laplacian of Gaussian",
        "Difference of Gaussian",
        "Determinant of Hessian",
    ]
    sequence = zip(blobs_list, colors, titles)

    fig2, axes = plt.subplots(1, 3, figsize=(12, 4), sharex=True, sharey=True)
    fig2.tight_layout()
    ax = axes.ravel()

    for idx, (blobs, color, title) in enumerate(sequence):
        ax[idx].set_title(title)
        ax[idx].imshow(in_img, cmap="gray")
        for blob in blobs:
            y, x, r = blob
            c = plt.Circle((x, y), r, color=color, linewidth=1, fill=False)
            ax[idx].add_patch(c)
        ax[idx].set_axis_off()
    return fig1, fig2
Recording 5 before and after applying a painted mask.
Three blob detection approaches applied to the mean of recording 5.

It's worth playing around with the parameters of the blob detection algorithms to get good results. Anecdotally, I felt the Laplacian of Gaussian worked best for our data.

There is quite a lot of flexibility in how to piece together the registration and cell detection. For our application, it made sense to run blob detection on a single central recording and use these ROIs for all other recordings. You might instead wish to run blob detection on the mean of a different chunk of frames, or on the frames of all recordings after registration. We used a central subset of the x-y space for registration and a separate subset for cell detection, but different configurations are possible; a simplified sketch of the cell-detection side follows.
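
For example, a rectangular stand-in for the hand-painted mask could look like this; the rectangle bounds are made up, and brush_mask from the snippet above is the real, painted version:

def rect_mask(img, r0=120, r1=380, c0=150, c1=420):
    """Hypothetical stand-in for brush_mask: zero everything outside a rectangle."""
    out = np.zeros_like(img)
    out[r0:r1, c0:c1] = img[r0:r1, c0:c1]
    return out


# rec5 is the reference recording, loaded elsewhere.
blobs = skimage.feature.blob_log(
    rect_mask(clip_and_norm(rec5.mean(axis=0))),
    min_sigma=1.2,
    max_sigma=2,
    threshold=0.035,
    overlap=0.1,
    threshold_rel=0.10,
    exclude_border=12,
)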

Transforming the blobs

The output of the blob detection algorithms is a list of circle coordinates and radii. What is nice about this is that we can transform these coordinates using the transformations calculated earlier from the registration step—this allows the cell regions of interest (the blobs) to be drawn over the original recording.

We calculate the blobs once from the mean of recording 5 and then transform the blobs for every frame using the transformations from the registration step.

def blob_rois(img):
    y_x_r = skimage.feature.blob_log(
        img,
        min_sigma=1.2,
        max_sigma=2,
        threshold=0.035,
        overlap=0.1,
        threshold_rel=0.10,
        exclude_border=12,
    )
    # Compute radii in the 3rd column.
    y_x_r[:, 2] = y_x_r[:, 2] * math.sqrt(2)
    # Sort by y, then x, ignore r.
    # In numpy, use lexsort. The last arg, y, is primary.
    sorted_indices = np.lexsort((y_x_r[:, 1], y_x_r[:, 0]))
    y_x_r = y_x_r[sorted_indices]
    return y_x_r


def transform_rois(T, rois):
    """Transform blob regions of interest with a 2D affine transform."""
    # T is a 2D affine transform.
    rois_yx = rois[:, :2]
    rois_yx_homo = np.concatenate([rois_yx, np.ones([len(rois_yx), 1])], axis=1)
    # [N, 3] <- ([3,3] @ [3, N]).T
    transformed_yx = (T @ rois_yx_homo.T).T

    # Transform the radii using the scale of the affine transform.
    radii = rois[:, 2]
    # Just approximate a scale
    scale_y = np.sqrt(T[0, 0] ** 2 + T[0, 1] ** 2)  # Scale along y
    scale_x = np.sqrt(T[1, 0] ** 2 + T[1, 1] ** 2)  # Scale along x
    mean_scale = (scale_y + scale_x) / 2  # Mean of the scales
    transformed_radii = radii * mean_scale

    res_rois = np.concatenate(
        [transformed_yx[:, :2], transformed_radii[:, np.newaxis]], axis=1
    )
    return res_rois
Blob regions of interest, transformed by the per-frame transformations obtained from image registration. The frames are upsampled.
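
To show how the pieces fit together, here is a sketch that maps the reference-frame ROIs onto each original frame and takes a crude per-cell trace as the mean intensity inside each transformed circle. This is an illustration under the assumptions noted in the comments, not our exact trace-extraction code.

import numpy as np
import skimage.draw

# rois: (N, 3) array of (y, x, r) from blob_rois on the recording-5 mean.
# rec2: original (unaligned) frames; ts: per-frame (y, x) transforms from
# rec_registration for the same recording.
traces = np.zeros((len(rois), len(rec2)))
for f, T in enumerate(ts):
    frame_rois = transform_rois(T, rois)
    for i, (y, x, r) in enumerate(frame_rois):
        rr, cc = skimage.draw.disk((y, x), r, shape=rec2[f].shape)
        traces[i, f] = rec2[f][rr, cc].mean()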

For flexible registration and cell detection without the need to train any models, give standard image processing techniques a try. It's far more satisfying than fighting with black-box libraries.