Frame registration and cell detection for 2-photon recordings
I worked on a project that involved carrying out 2-photon recording of the tectum of a zebrafish larva. For a single larva, we collected 14 recordings, each 15 minutes long and separated by gaps of roughly 2 minutes. The larva moved around, especially at the beginning, and to extract cell traces, we first needed a way to align frames and detect cells within and across recordings.


I had a go at using Caiman, Suite2P and jnormcorre for frame registration (alignment). All three were somewhat frustrating to use, and I struggled to get them to align frames across multiple recordings. In the end, I found much more success using keypoint detection and matching.
Alignment with keypoints
We can use feature detection and matching to create transforms between recordings, or at a finer scale, between chunks of frames within a recording.
The function below finds matching features in two images and estimates a transformation between the two keypoint sets. By passing in the means of two chunks of frames, we can obtain a transformation that aligns one chunk to another. It uses SIFT, but any feature detector and descriptor extractor could be used.
import logging
import multiprocessing as mp
from functools import partial

import numpy as np
import skimage

_logger = logging.getLogger(__name__)
def detect_and_shift(src, target, filter_fn=None):
    """Calculate a transform to convert the src image to look like target.

    Steps:
    1. Calculate SIFT image keypoints in both images.
    2. Retain only the keypoints present in both images.
    3. Retain only the keypoints within the area we are most concerned
       about. The point of this step is that we don't care so much about
       the other areas and are fine if they shift around. If we included
       keypoints from all areas, then a more modest transform would be
       produced that goes for a consensus across the whole image. Instead,
       we are happy for less interesting parts of the image to shift so
       that we keep interesting areas as steady as possible.
    4. Use an off-the-shelf RANSAC optimizer to choose a matrix transform
       that moves target points to src points. This is the mapping that
       skimage.transform.warp expects (as inverse_map) in order to warp
       the src image towards the target image.
    """
    # 1. Extract keypoints.
    # c_dog has been reduced quite a bit from the default. Reducing c_dog
    # allows lower contrast features to be included. Some of the most salient
    # areas of the images are quite low contrast, so I reduced this value.
    extractor = skimage.feature.SIFT(
        n_octaves=2,
        n_scales=5,
        sigma_min=1.0,
        sigma_in=0.1,
        c_edge=5,
        c_dog=0.005,
        upsampling=1,
    )
    extractor.detect_and_extract(src)
    kp_src, desc_src = extractor.keypoints, extractor.descriptors
    extractor.detect_and_extract(target)
    kp_target, desc_target = extractor.keypoints, extractor.descriptors

    # 2. Match and filter keypoints.
    # 2.1. Match descriptors.
    matches = skimage.feature.match_descriptors(
        desc_src, desc_target, max_ratio=0.6, cross_check=True
    )
    # 2.2. Filter to keypoints within the area of interest.
    if filter_fn is not None:
        matches = np.array(
            [
                m
                for m in matches
                if filter_fn(kp_src[m[0]]) and filter_fn(kp_target[m[1]])
            ]
        )
    # RANSAC below draws min_samples=8 matches per trial, so at least that
    # many matches must survive filtering.
    if matches.shape[0] < 8:
        raise ValueError("Not enough matches to estimate transformation.")
    # Extract matched points.
    points_src = kp_src[matches[:, 0]]
    points_target = kp_target[matches[:, 1]]

    # 3. Transform.
    # 3.1. Robust affine transformation estimation using RANSAC.
    model, inliers = skimage.measure.ransac(
        # skimage keypoints are (row, col) while transforms expect (x, y),
        # hence the flips:
        # https://stackoverflow.com/a/62332179/754300
        # https://github.com/scikit-image/scikit-image/issues/1749
        (np.flip(points_target, axis=-1), np.flip(points_src, axis=-1)),
        skimage.transform.AffineTransform,
        min_samples=8,
        residual_threshold=1,
        max_trials=1000,
        rng=np.random.default_rng(123),
    )
    if inliers is None or not inliers.any():
        raise ValueError("RANSAC failed to find a valid transformation.")
    # Alternative (least squares, no outlier rejection):
    # tform = skimage.transform.AffineTransform()
    # tform.estimate(points_target, points_src)
    # 3.2. Apply the transformation.
    aligned_img = skimage.transform.warp(
        src.astype(np.float32),
        model,
        output_shape=target.shape,
        preserve_range=False,
    )
    return points_src, points_target, aligned_img, model
def clip_and_norm(img):
    """Clip and normalize. The range is determined by the 2-photon setup."""
    r = (54600, 55000)
    img = np.clip(img, a_min=r[0], a_max=r[1]) - r[0]
    img = img / (r[1] - r[0])
    return img


def mean_clip_norm(x):
    """Mean over frames, then clip and norm (module level for multiprocessing)."""
    return clip_and_norm(x.mean(axis=0))
def rec_registration(rec, ref_img=None, n_batch=1, filter_fn=None):
    """Run registration on chunks of a recording, targeting ref_img.

    The recording is split into n_batch chunks, and each chunk is
    registered against the reference image. If ref_img is not provided,
    the mean of the middle chunk is used as the reference image.

    Larger chunks mean more stable registration, as the mean of larger chunks
    is less affected by the activity of the cells. But larger chunks mean
    fewer chunks, which means fewer transformations cover the recording.
    For frames within a chunk, the transformation used is a linear
    interpolation between the neighbouring transformations.
    """
    B = len(rec) // n_batch
    # l_B is our designated "mid" point.
    l_B = B // 2
    u_B = (B + 1) // 2
    _logger.debug(f"{len(rec)=}, {B=}, {l_B=}, {u_B=}")
    # Get mids by edges+l_B (we only need the left edges, so throw away the
    # last edge).
    mids = np.arange(0, len(rec) + 1, B)[:-1] + l_B
    assert len(mids) == n_batch, (len(mids), n_batch, mids)
    chunks = [rec[m - l_B : m + u_B] for m in mids]
    if ref_img is None:
        # If no reference image is given, use the middle chunk.
        ref_img = clip_and_norm(chunks[len(chunks) // 2].mean(axis=0))

    def interp(T0, T1, p0, p1, query_idx):
        """Linear interpolation between two transforms."""
        assert p1 > p0
        dT = T1 - T0
        rel = (query_idx - p0) / (p1 - p0)
        return T0 + rel * dT

    with mp.Pool(min(len(chunks), mp.cpu_count())) as pool:
        chunk_aves = []
        pool.map_async(
            mean_clip_norm, chunks, callback=chunk_aves.extend
        ).wait()
        img_models = []
        worker_fn = partial(
            detect_and_shift, target=ref_img, filter_fn=filter_fn
        )
        pool.map_async(
            worker_fn,
            chunk_aves,
            callback=img_models.extend,
            error_callback=lambda e: _logger.error(str(e)),
        ).wait()
        if len(img_models) != len(chunk_aves):
            assert len(img_models) < len(chunk_aves)
            raise ValueError(
                "Not enough models returned. Most likely one of the chunks "
                "didn't have enough shared features with the reference image. "
                f"{len(img_models)=}, {len(chunk_aves)=}"
            )
        Ts = [x[3].params for x in img_models]
        ts = []
        warp_ts = []
        for f in range(len(rec)):
            # Clamp so we don't index past the last chunk.
            t0_idx = min(n_batch - 1, f // B)
            t1_idx = min(n_batch - 1, t0_idx + 1)
            idx0 = t0_idx * B + l_B
            idx1 = t1_idx * B + l_B
            if idx0 == idx1:
                t = Ts[t0_idx]
            else:
                t = interp(Ts[t0_idx], Ts[t1_idx], idx0, idx1, f)
            warp_ts.append(t)
            # skimage transforms use (x, y) order; swap rows and columns so
            # the stored matrix is in (row, col) order for use elsewhere.
            yx_t = t.copy()
            tmp = yx_t.copy()
            yx_t[0, 0] = tmp[1, 1]
            yx_t[0, 1] = tmp[1, 0]
            yx_t[0, 2] = tmp[1, 2]
            yx_t[1, 0] = tmp[0, 1]
            yx_t[1, 1] = tmp[0, 0]
            yx_t[1, 2] = tmp[0, 2]
            ts.append(yx_t)
        frames = pool.starmap(
            partial(skimage.transform.warp, preserve_range=True),
            [(rec[i].astype(np.float32), warp_ts[i]) for i in range(len(rec))],
        )
    frames = np.array(frames)
    assert frames.shape[1] == frames.shape[2] == rec.shape[1]
    ts = np.array(ts)
    return frames, ts
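The following is a rough sketch (not part of the original pipeline) of how rec_registration might be driven for a session's worth of recordings. Here, load_recording is a hypothetical loader and n_batch=15 is an arbitrary choice.
recs = [load_recording(i) for i in range(14)]  # hypothetical loader

# Use the middle recording's mean as the common reference so that ROIs
# detected on it can be reused for every recording.
ref_img = clip_and_norm(recs[len(recs) // 2].mean(axis=0))

registered, transforms = [], []
for rec in recs:
    frames, ts = rec_registration(rec, ref_img=ref_img, n_batch=15)
    registered.append(frames)
    transforms.append(ts)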
The figure below shows SIFT feature detection and matching between the averages of two whole recordings.


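To draw this kind of match figure, scikit-image has a plotting helper. A minimal sketch follows, assuming mean_a and mean_b are clipped-and-normalized mean images of two recordings; note that plot_matches is deprecated in recent scikit-image releases in favour of plot_matched_features.
import matplotlib.pyplot as plt

# Detect and match SIFT features between the two mean images.
extractor = skimage.feature.SIFT()
extractor.detect_and_extract(mean_a)
kp_a, desc_a = extractor.keypoints, extractor.descriptors
extractor.detect_and_extract(mean_b)
kp_b, desc_b = extractor.keypoints, extractor.descriptors
matches = skimage.feature.match_descriptors(
    desc_a, desc_b, max_ratio=0.6, cross_check=True
)
# Draw the images side by side with lines joining matched keypoints.
fig, ax = plt.subplots(figsize=(10, 5))
skimage.feature.plot_matches(ax, mean_a, mean_b, kp_a, kp_b, matches)
ax.set_axis_off()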
Cell detection as blob detection
Continuing the DIY approach, the next step is to use blob detection to detect cell bodies. scikit-image has a few blob detection algorithms; following the example from the docs, we apply three different blob detection algorithms to the mean stack from recording 5. While we are at it, we filter by a hand-painted mask to select the area we are interested in.
import math

import matplotlib.pyplot as plt


def blob_figures(in_img):
    """Apply three blob detectors to in_img and plot the results."""
    # brush_mask applies a hand-painted mask over the area of interest.
    masked = brush_mask(in_img)
    fig1, ax = plt.subplots(1, 2)
    ax[0].imshow(in_img, cmap="gray")
    ax[1].imshow(masked, cmap="gray")
    ax[0].set_title("Input")
    ax[1].set_title("Masked")
    ax[0].set_axis_off()
    ax[1].set_axis_off()
    fig1.tight_layout()

    blobs_log = skimage.feature.blob_log(
        masked,
        min_sigma=1.2,
        max_sigma=2,
        threshold=0.035,
        overlap=0.1,
        threshold_rel=0.10,
        exclude_border=12,
    )
    blobs_dog = skimage.feature.blob_dog(
        masked,
        min_sigma=1.4,
        max_sigma=2,
        threshold=0.04,
        overlap=0.2,
        threshold_rel=0.05,
        exclude_border=12,
    )
    blobs_doh = skimage.feature.blob_doh(
        masked,
        min_sigma=0.0,
        max_sigma=3,
        threshold=0.002,
        overlap=0.2,
        threshold_rel=0.003,
    )
    # For LoG and DoG, the 3rd column is sigma; convert it to an approximate
    # radius (r ≈ sqrt(2)·sigma). For DoH, sigma already approximates the radius.
    blobs_log[:, 2] = blobs_log[:, 2] * math.sqrt(2)
    blobs_dog[:, 2] = blobs_dog[:, 2] * math.sqrt(2)

    blobs_list = [blobs_log, blobs_dog, blobs_doh]
    colors = ["yellow", "lime", "red"]
    titles = [
        "Laplacian of Gaussian",
        "Difference of Gaussian",
        "Determinant of Hessian",
    ]
    sequence = zip(blobs_list, colors, titles)
    fig2, axes = plt.subplots(1, 3, figsize=(12, 4), sharex=True, sharey=True)
    fig2.tight_layout()
    ax = axes.ravel()
    for idx, (blobs, color, title) in enumerate(sequence):
        ax[idx].set_title(title)
        ax[idx].imshow(in_img, cmap="gray")
        for blob in blobs:
            y, x, r = blob
            c = plt.Circle((x, y), r, color=color, linewidth=1, fill=False)
            ax[idx].add_patch(c)
        ax[idx].set_axis_off()
    return fig1, fig2


It's worth playing around with the parameters of the blob detection algorithms to get nice results. Anecdotally, I felt the Laplacian of Gaussian worked best for our data.
There is quite a lot of flexibility in how to piece together the registration and cell detection. For our application, it made sense to run blob detection on a single central recording and use these ROIs for all other recordings; a sketch of this wiring follows below. You might instead wish to run blob detection on the mean of a different chunk of frames, or on the frames of all recordings after registration. We used a central subset of the x-y space for registration, and a separate subset for cell detection, but other configurations are possible.
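As a rough sketch of that wiring, assuming the recs, ref_img and transforms variables from the registration sketch earlier, and the blob_rois and transform_rois helpers defined below:
# Detect ROIs once, in the reference image's coordinate space (after
# applying the hand-painted mask, as above).
rois = blob_rois(brush_mask(ref_img))

# Each recording's per-frame transforms from rec_registration can then map
# these shared ROIs onto the original, unregistered frames.
per_rec_rois = [
    [transform_rois(ts[f], rois) for f in range(len(ts))] for ts in transforms
]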
Transforming the blobs
The output of the blob detection algorithms is a list of circle coordinates and radii. What is nice about this is that we can transform these coordinates using the transformations calculated earlier in the registration step, which allows the cell regions of interest (the blobs) to be drawn over the original recording.
We calculate the blobs once from the mean of recording 5 and then transform the blobs for every frame using the transformations from the registration step.
def blob_rois(img):
    y_x_r = skimage.feature.blob_log(
        img,
        min_sigma=1.2,
        max_sigma=2,
        threshold=0.035,
        overlap=0.1,
        threshold_rel=0.10,
        exclude_border=12,
    )
    # Convert the sigma in the 3rd column to an approximate radius.
    y_x_r[:, 2] = y_x_r[:, 2] * math.sqrt(2)
    # Sort by y, then x, ignoring r. np.lexsort treats its last key, y,
    # as the primary sort key.
    sorted_indices = np.lexsort((y_x_r[:, 1], y_x_r[:, 0]))
    y_x_r = y_x_r[sorted_indices]
    return y_x_r
def transform_rois(T, rois):
    """Transform blob regions of interest with a 2D affine transform."""
    # T is a 3x3 affine transform matrix in (row, col) order.
    rois_yx = rois[:, :2]
    rois_yx_homo = np.concatenate([rois_yx, np.ones([len(rois_yx), 1])], axis=1)
    # [N, 3] <- ([3, 3] @ [3, N]).T
    transformed_yx = (T @ rois_yx_homo.T).T
    # Transform the radii by the scale of the transform. An affine transform
    # can scale each axis differently, so approximate with the mean scale.
    radii = rois[:, 2]
    scale_y = np.sqrt(T[0, 0] ** 2 + T[0, 1] ** 2)  # scale along y
    scale_x = np.sqrt(T[1, 0] ** 2 + T[1, 1] ** 2)  # scale along x
    mean_scale = (scale_y + scale_x) / 2
    transformed_radii = radii * mean_scale
    res_rois = np.concatenate(
        [transformed_yx[:, :2], transformed_radii[:, np.newaxis]], axis=1
    )
    return res_rois
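To close the loop and get the cell traces mentioned at the start, the transformed ROIs can be used to sample the original frames. Below is a minimal, unoptimized sketch (not the original pipeline), assuming rec and ts come from rec_registration and rois from blob_rois:
def extract_traces(rec, ts, rois):
    """Mean intensity inside each transformed ROI disk, per frame."""
    n_frames, h, w = rec.shape
    traces = np.zeros((len(rois), n_frames))
    yy, xx = np.mgrid[0:h, 0:w]
    for f in range(n_frames):
        # Move the reference-space ROIs onto this frame.
        frame_rois = transform_rois(ts[f], rois)
        for i, (y, x, r) in enumerate(frame_rois):
            disk = (yy - y) ** 2 + (xx - x) ** 2 <= r**2
            traces[i, f] = rec[f][disk].mean()
    return traces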
For flexible registration and cell detection without the need to train any models, standard image processing techniques go a long way. They are also far more satisfying than fighting with black-box libraries.