Tutorial about drift analysis and correction

Lateral drift correction is useful in most SMLM experiments. The amount of drift can be determined with a method based on image cross-correlation or with an iterative closest point algorithm.

We demonstrate drift analysis and correction on simulated data.

from pathlib import Path

%matplotlib inline

import numpy as np
import matplotlib.pyplot as plt
import scipy.stats as stats

import locan as lc
lc.show_versions(system=False, dependencies=False, verbose=False)
Locan:
   version: 0.22.0.dev32+g4bfc3ab8b

Python:
   version: 3.11.14

Synthetic data

We use synthetic data that follows a Neyman-Scott spatial distribution (blobs). The intensity values are exponentially distributed with a scale of 1000 above an offset of 500, and the number of localizations per frame follows a Poisson distribution with a mean of 3:

rng = np.random.default_rng(seed=1)
intensity_mean = 1000
localizations_per_frame_mean = 3
dat_blob = lc.simulate_Thomas(parent_intensity=1e-4, region=((0, 1000), (0, 1000)), cluster_mu=1000, cluster_std=10, seed=rng)
dat_blob.dataframe['intensity'] = stats.expon.rvs(scale=intensity_mean, size=len(dat_blob), loc=500)
dat_blob.dataframe['frame'] = lc.simulate_frame_numbers(n_samples=len(dat_blob), lam=localizations_per_frame_mean, seed=rng)

dat_blob = lc.LocData.from_dataframe(dataframe=dat_blob.data)

print('Data head:')
print(dat_blob.data.head(), '\n')
print('Summary:')
dat_blob.print_summary()
print('Properties:')
print(dat_blob.properties)
Data head:
   position_x  position_y  cluster_label    intensity  frame
0  915.763326  465.421770             28  4232.493866      0
1  729.786200   81.672520             56   929.007997      1
2  869.708004    9.266125             41   867.784458      1
3  521.743493   54.508814             14  2124.239866      2
4  430.972011  375.461311             61  2188.359508      3 

Summary:
identifier: "2"
comment: ""
source: DESIGN
state: RAW
element_count: 98201
frame_count: 31016
creation_time {
  2026-04-30T08:33:22.672057Z
}

Properties:
{'localization_count': 98201, 'position_x': np.float64(495.13588743804684), 'uncertainty_x': np.float64(0.8971814298938875), 'position_y': np.float64(507.65355920347866), 'uncertainty_y': np.float64(0.8964573205917536), 'intensity': np.float64(147154096.15695238), 'frame': np.int64(0), 'region_measure_bb': np.float64(999959.2361869529), 'localization_density_bb': np.float64(0.09820500321039116), 'subregion_measure_bb': np.float64(3999.918472062067)}
lc.render_2d(dat_blob, bin_size=10, rescale='equal');
../../_images/3c7f1a320b6bb972c8a17bbf36621a9cbebd78416ca21ac7d8d7edb7178f20c8.png
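
For intuition, the Neyman-Scott (Thomas) process used above can be sketched with NumPy alone. This is a minimal illustration and not locan's implementation. The function name thomas_process_sketch is hypothetical, the parameters are assumed to mean parent density, mean number of offspring per parent, and standard deviation of the Gaussian scatter, and edge effects at the region border are ignored.

```python
import numpy as np


def thomas_process_sketch(parent_intensity, region, cluster_mu, cluster_std, rng):
    """Draw 2D points from a simple Thomas (Neyman-Scott) cluster process."""
    (x_min, x_max), (y_min, y_max) = region
    area = (x_max - x_min) * (y_max - y_min)
    # Parents: homogeneous Poisson process on the region.
    n_parents = rng.poisson(parent_intensity * area)
    parents = rng.uniform((x_min, y_min), (x_max, y_max), size=(n_parents, 2))
    # Offspring: Poisson-distributed count per parent, normally scattered around it.
    counts = rng.poisson(cluster_mu, size=n_parents)
    centers = np.repeat(parents, counts, axis=0)
    points = centers + rng.normal(scale=cluster_std, size=centers.shape)
    labels = np.repeat(np.arange(n_parents), counts)
    return points, labels


rng = np.random.default_rng(seed=1)
points, labels = thomas_process_sketch(
    parent_intensity=1e-4,
    region=((0, 1000), (0, 1000)),
    cluster_mu=1000,
    cluster_std=10,
    rng=rng,
)
print(points.shape, len(np.unique(labels)))
```

With these values the expected number of points is 1e-4 * 1e6 * 1000 = 100000, the same order of magnitude as the element_count in the summary above.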

Add linear drift

We add linear drift with a velocity given in length units per frame.

dat_blob_with_drift = lc.add_drift(dat_blob, velocity=(0.002, 0.001), seed=rng)
fig, axes = plt.subplots(1, 3, figsize=(15, 5))
lc.render_2d(dat_blob_with_drift, ax=axes[0], bin_size=10);
lc.render_2d(dat_blob_with_drift, ax=axes[1], bin_size=2, rescale='equal', bin_range=((0, 500),(0, 500)));
lc.render_2d_mpl(dat_blob_with_drift, ax=axes[2], other_property='frame', bin_size=2, bin_range=((0, 500),(0, 500)), cmap='viridis');
../../_images/d5807fa56bb75bd80f776281e3b56bcb9129aa8d45ef099300d012beab6175a3.png
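
The effect of a constant drift velocity can be sketched in a few lines of NumPy. The sketch assumes that the displacement of a localization is simply velocity times frame number; the helper add_linear_drift_sketch is hypothetical and is not the implementation of lc.add_drift.

```python
import numpy as np


def add_linear_drift_sketch(coordinates, frames, velocity):
    """Shift each localization by velocity * frame (hypothetical helper)."""
    coordinates = np.asarray(coordinates, dtype=float)
    frames = np.asarray(frames, dtype=float)
    velocity = np.asarray(velocity, dtype=float)
    # Broadcast frames over the coordinate axes: shape (n, 1) * (2,) -> (n, 2).
    return coordinates + frames[:, np.newaxis] * velocity


# The same emitter position observed in three different frames.
coordinates = np.array([[100.0, 200.0], [100.0, 200.0], [100.0, 200.0]])
frames = np.array([0, 1000, 30000])
drifted = add_linear_drift_sketch(coordinates, frames, velocity=(0.002, 0.001))
print(drifted)
```

Under this assumption a localization recorded in frame 30000 is displaced by 60 length units in x and 30 in y, which is why the blobs in the rendered images appear smeared along the drift direction.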

Estimate RMS errors

Since the ground truth is known, we can compute the root mean squared error (RMSE) between the original localization coordinates and the coordinates after drift, and later after correction.

def rmse(locdata, other_locdata):
    """Return the per-axis root mean squared error between the coordinates of two LocData objects."""
    return np.sqrt(np.mean(np.square(np.subtract(locdata.coordinates, other_locdata.coordinates)), axis=0))
rmse(dat_blob, dat_blob_with_drift).round(2)
array([37.76, 18.88])
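
As a plausibility check, the RMSE expected for purely linear drift can be estimated directly. The sketch below rests on two assumptions that are not stated in the tutorial output: the displacement equals velocity times frame number, and the localizations are spread evenly over frames 0 to N-1, with N taken as the number of localizations divided by the mean number of localizations per frame.

```python
import numpy as np

velocity = np.array([0.002, 0.001])
n_localizations = 98201
localizations_per_frame_mean = 3
# ASSUMPTION: total number of frames; the tutorial output does not report it.
n_frames = round(n_localizations / localizations_per_frame_mean)

frames = np.arange(n_frames, dtype=float)
# RMSE per axis = |velocity| * sqrt(mean(frame**2)),
# which is roughly |velocity| * n_frames / sqrt(3).
expected_rmse = np.abs(velocity) * np.sqrt(np.mean(frames**2))
print(expected_rmse.round(2))
```

The estimate comes out at roughly 37.8 and 18.9, close to the values measured above.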

Estimate drift

Drift can be estimated by comparing different chunks of successive localizations using either an “iterative closest point” (icp) algorithm or a “cross-correlation” (cc) algorithm. By default, the icp algorithm is applied.

%%time
drift = lc.Drift(chunk_size=10_000, target='first', method='icp').compute(dat_blob_with_drift)
CPU times: user 3.35 s, sys: 7.23 ms, total: 3.36 s
Wall time: 1.89 s
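
To illustrate the idea behind iterative closest point registration, here is a translation-only sketch in NumPy. It is a deliberately reduced version under stated assumptions: it estimates only an offset (no matrix), uses a brute-force nearest-neighbor search, and is not the registration routine that locan calls. The function name icp_translation_sketch and the jittered-grid test data are hypothetical.

```python
import numpy as np


def icp_translation_sketch(source, target, n_iterations=20, tolerance=1e-9):
    """Estimate the offset that moves `source` onto `target` (translation only)."""
    source = np.asarray(source, dtype=float)
    target = np.asarray(target, dtype=float)
    offset = np.zeros(source.shape[1])
    for _ in range(n_iterations):
        moved = source + offset
        # Brute-force nearest target point for every moved source point.
        differences = moved[:, np.newaxis, :] - target[np.newaxis, :, :]
        squared_distances = (differences**2).sum(axis=-1)
        nearest = target[squared_distances.argmin(axis=1)]
        # For fixed correspondences the best translation is the mean residual.
        step = (nearest - moved).mean(axis=0)
        offset += step
        if np.linalg.norm(step) < tolerance:
            break
    return offset


rng = np.random.default_rng(0)
# Jittered grid as a stand-in for a reference chunk of localizations.
grid = np.stack(np.meshgrid(np.arange(10), np.arange(10)), axis=-1).reshape(-1, 2) * 10.0
reference_chunk = grid + rng.uniform(-2, 2, size=grid.shape)
drifted_chunk = reference_chunk + np.array([1.5, -0.8])

offset = icp_translation_sketch(drifted_chunk, reference_chunk)
print(offset.round(3))
```

The estimated offset is the negative of the applied shift, because it describes how the drifted chunk has to be moved to match the reference chunk.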

Transformations that register the different data chunks are represented by a transformation matrix and a transformation offset that together specify an affine transformation. The transformation parameters are kept in the transformations attribute.

drift.transformations
[Transformation(matrix=array([[1., 0.],
        [0., 1.]]), offset=array([0., 0.])),
 Transformation(matrix=array([[ 9.98792443e-01,  4.43801557e-04],
        [-4.43801557e-04,  9.98792443e-01]]), offset=array([-6.56553239, -2.7171172 ])),
 Transformation(matrix=array([[ 0.99716405,  0.00268322],
        [-0.00268322,  0.99716405]]), offset=array([-13.15345296,  -4.05470652])),
 Transformation(matrix=array([[ 9.97764214e-01,  5.81845172e-04],
        [-5.81845172e-04,  9.97764214e-01]]), offset=array([-19.46087507,  -8.60939708])),
 Transformation(matrix=array([[ 9.98398817e-01,  6.25680278e-04],
        [-6.25680278e-04,  9.98398817e-01]]), offset=array([-25.94756073, -12.08151622])),
 Transformation(matrix=array([[ 9.98495460e-01, -4.65982241e-04],
        [ 4.65982241e-04,  9.98495460e-01]]), offset=array([-32.71859007, -15.76861207])),
 Transformation(matrix=array([[ 0.99857889,  0.00120287],
        [-0.00120287,  0.99857889]]), offset=array([-39.95058187, -18.65026698])),
 Transformation(matrix=array([[ 9.96629116e-01, -5.37378964e-04],
        [ 5.37378964e-04,  9.96629116e-01]]), offset=array([-44.18251022, -22.15203275])),
 Transformation(matrix=array([[ 9.98851114e-01, -4.33292180e-04],
        [ 4.33292180e-04,  9.98851114e-01]]), offset=array([-52.9063275 , -26.00989829])),
 Transformation(matrix=array([[ 9.97771654e-01, -6.47144209e-04],
        [ 6.47144209e-04,  9.97771654e-01]]), offset=array([-58.14458854, -28.33975215]))]
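
How such a matrix and offset act on coordinates can be sketched with plain NumPy arrays. The convention x' = matrix @ x + offset is an assumption here, and apply_affine is a hypothetical helper rather than a locan function. The numbers are those of the second transformation in the list above.

```python
import numpy as np

# Second transformation from the output above.
matrix = np.array([[9.98792443e-01, 4.43801557e-04],
                   [-4.43801557e-04, 9.98792443e-01]])
offset = np.array([-6.56553239, -2.7171172])


def apply_affine(points, matrix, offset):
    """Apply x' = matrix @ x + offset to an (n, 2) array (assumed convention)."""
    return np.asarray(points, dtype=float) @ matrix.T + offset


points = np.array([[0.0, 0.0], [500.0, 500.0]])
print(apply_affine(points, matrix, offset))

# A matrix of the form [[a, b], [-b, a]] is a rotation combined with an
# isotropic scale; both can be read off the first column.
scale = np.hypot(matrix[0, 0], matrix[1, 0])
angle = np.arctan2(matrix[1, 0], matrix[0, 0])
print(scale, angle)
```

For this chunk the scale deviates from one by less than 0.2 % and the rotation angle is below one milliradian, so the transformation is dominated by the offset.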

The parameters can be visualized using the plot method. The matrix in this case is close to the identity matrix.

drift.plot(transformation_component='matrix', element=None);
plt.legend();
../../_images/e63743b374ef76e25eb18ac9447d4907e9d3d14a4817ac28cb4a1a2916fa7dd6.png
drift.plot(transformation_component='offset', element=None)
plt.legend();
../../_images/df292641b97d897c5d00dfa1fbf1072dcb11831e99b5f0eb11ceee0618c3cfdf.png

Model drift

A continuous transformation model as a function of frame number is estimated by fitting the individual transformation components with the specified fit models. Fit models can be provided as DriftComponent instances or as strings that name standard model functions.

from lmfit.models import ConstantModel, LinearModel, PolynomialModel

drift.fit_transformations(slice_data=slice(None), offset_models=(lc.DriftComponent('spline', s=100), 'linear'), verbose=True);
../../_images/1d87b5e38746ba5a6d4eae47074dd4ca806985afc189a4a5d5a30518ef3b066c.png ../../_images/9c89c63c75a8c099946121b9cb1ad4cc433633e69588e28355ef5464b70f72de.png
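
The principle of fitting a drift model to the chunk offsets can be reproduced with an ordinary least-squares line. The sketch below uses np.polyfit in place of the lmfit and spline models used by locan. The offsets are copied (rounded) from the transformations listed earlier. The frame position of each chunk is an assumption: chunks of 10000 localizations at a mean of 3 localizations per frame are taken to be evenly spaced and centered.

```python
import numpy as np

# Offsets of the ten chunk transformations (rounded to four decimals).
offsets = np.array([
    [0.0, 0.0],
    [-6.5655, -2.7171],
    [-13.1535, -4.0547],
    [-19.4609, -8.6094],
    [-25.9476, -12.0815],
    [-32.7186, -15.7686],
    [-39.9506, -18.6503],
    [-44.1825, -22.1520],
    [-52.9063, -26.0099],
    [-58.1446, -28.3398],
])

# ASSUMPTION: evenly spaced chunk centers; the true positions are not shown.
frames_per_chunk = 10_000 / 3
chunk_frames = (np.arange(len(offsets)) + 0.5) * frames_per_chunk

# One straight line per coordinate: row 0 holds slopes, row 1 intercepts.
slopes, intercepts = np.polyfit(chunk_frames, offsets, deg=1)
print(slopes)
```

Under this assumption the fitted slopes are close to the negative of the simulated velocity (0.002, 0.001), as expected for a correction that undoes linear drift.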

The fit models are represented as DriftComponent and can be accessed through the transformation_models attribute.

drift.transformation_models
{'matrix': None,
 'offset': [<locan.analysis.drift.DriftComponent at 0x7cdb0a972010>,
  <locan.analysis.drift.DriftComponent at 0x7cdb0a7ce210>]}
drift.transformation_models['offset'][0].type
'spline'
drift.transformation_models['offset'][0].eval(0)
array(3.33056629)

Each DriftComponent carries detailed information about the fit in the model_result attribute. In most cases, except for splines, this is an lmfit.ModelResult object.

drift.transformation_models['offset'][0].model_result
(array([ 1639.0867    ,  1639.0867    ,  1639.0867    ,  1639.0867    ,
        31356.39897573, 31356.39897573, 31356.39897573, 31356.39897573]),
 array([ 3.39564863e-02, -1.98083279e+01, -3.82432449e+01, -5.82808309e+01,
         0.00000000e+00,  0.00000000e+00,  0.00000000e+00,  0.00000000e+00]),
 3)
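
The spline result printed above looks like a (knots, coefficients, degree) tuple in the layout used by SciPy's spline routines; that reading is an assumption. Because the knot vector contains only the two boundary knots, each repeated four times, the cubic B-spline reduces to a single cubic Bezier segment, which can be evaluated with NumPy alone. The helper name below is hypothetical. For general knot vectors scipy.interpolate.splev would be the appropriate tool.

```python
from math import comb

import numpy as np

# The tuple printed above, read as (knots, coefficients, degree).
knots = np.array([1639.0867, 1639.0867, 1639.0867, 1639.0867,
                  31356.39897573, 31356.39897573, 31356.39897573, 31356.39897573])
coefficients = np.array([3.39564863e-02, -1.98083279e+01, -3.82432449e+01,
                         -5.82808309e+01, 0.0, 0.0, 0.0, 0.0])
degree = 3


def eval_single_segment_spline(x, knots, coefficients, degree):
    """Evaluate a B-spline without interior knots as a Bezier polynomial."""
    start, stop = knots[0], knots[-1]
    t = (np.asarray(x, dtype=float) - start) / (stop - start)
    control_points = coefficients[: degree + 1]
    # Bernstein basis polynomials of the given degree.
    basis = [comb(degree, i) * t**i * (1 - t) ** (degree - i) for i in range(degree + 1)]
    return sum(c * b for c, b in zip(control_points, basis))


value_at_zero = eval_single_segment_spline(0, knots, coefficients, degree)
print(value_at_zero)
```

Evaluated at frame 0, the sketch gives about 3.33, in agreement with the result of eval(0) shown above.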
drift.transformation_models['offset'][1].type
'linear'
drift.transformation_models['offset'][1].model_result

Fit Result

Model: Model(linear)

Drift correction

The estimated drift can be corrected by applying the estimated transformation to each chunk of localizations (from_model=False).

%%time
drift.apply_correction(from_model=False);
CPU times: user 279 ms, sys: 6.07 ms, total: 285 ms
Wall time: 284 ms

The same correction can be applied to any other localization dataset.

fig, axes = plt.subplots(1, 2, figsize=(15, 5))
lc.render_2d(drift.locdata_corrected, ax=axes[0], bin_size=2, rescale='equal', bin_range=((0, 200),(0, 200)));
lc.render_2d_mpl(drift.locdata_corrected, ax=axes[1], other_property='frame', bin_size=2, bin_range=((0, 200),(0, 200)), cmap='viridis');
../../_images/e5d86c7c3d48e2eead3df5ae14a98b06f0d7e75489bc3ef39faa0d4c30e23212.png
rmse(dat_blob, drift.locdata_corrected).round(2)
array([9.55, 4.82])

Alternatively, the estimated drift can be corrected by applying a transformation to each individual localization using the fitted drift models (from_model=True).

%%time
drift.apply_correction(from_model=True)
CPU times: user 45.4 ms, sys: 3.95 ms, total: 49.3 ms
Wall time: 49.1 ms
Drift(chunks=None, chunk_size=10000, n_chunks=None, target=first, method=icp, kwargs_chunk=None, kwargs_register=None)
fig, axes = plt.subplots(1, 2, figsize=(15, 5))
lc.render_2d(drift.locdata_corrected, ax=axes[0], bin_size=2, rescale='equal', bin_range=((0, 200),(0, 200)));
lc.render_2d_mpl(drift.locdata_corrected, ax=axes[1], other_property='frame', bin_size=2, bin_range=((0, 200),(0, 200)), cmap='viridis');
../../_images/de62b1405a0ab129795650132ba26588138317ad4891acc4fafbd1d16205fd20.png
rmse(dat_blob, drift.locdata_corrected).round(2)
array([3.93, 2.77])
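
Why the model-based correction tends to leave a smaller error than the chunk-wise correction can be seen in a small NumPy sketch. It is a simplified stand-in: the per-chunk offsets are computed from the known drift instead of a registration step, the drift is strictly linear, and all names and values are hypothetical.

```python
import numpy as np


def rmse_arrays(a, b):
    """Per-axis root mean squared error between two coordinate arrays."""
    return np.sqrt(np.mean((a - b) ** 2, axis=0))


rng = np.random.default_rng(0)
n_localizations, chunk_size = 1000, 100
velocity = np.array([0.02, 0.01])
frames = np.arange(n_localizations)  # one localization per frame
truth = rng.uniform(0, 100, size=(n_localizations, 2))
drifted = truth + frames[:, np.newaxis] * velocity

# Stand-in for registration: offset of each chunk relative to the first chunk.
chunk_ids = frames // chunk_size
chunk_centers = frames.reshape(-1, chunk_size).mean(axis=1)
chunk_offsets = -(chunk_centers - chunk_centers[0])[:, np.newaxis] * velocity

# Chunk-wise correction: one constant offset per chunk.
corrected_chunks = drifted + chunk_offsets[chunk_ids]

# Model-based correction: fit a line per axis and evaluate it for every frame.
slopes, intercepts = np.polyfit(chunk_centers, chunk_offsets, deg=1)
corrected_model = drifted + frames[:, np.newaxis] * slopes + intercepts

rmse_drifted = rmse_arrays(truth, drifted)
rmse_chunks = rmse_arrays(truth, corrected_chunks)
rmse_model = rmse_arrays(truth, corrected_model)
print(rmse_drifted.round(2), rmse_chunks.round(2), rmse_model.round(2))
```

In this sketch the chunk-wise correction cannot remove the drift that accumulates within a chunk, whereas the model-based correction leaves only a constant shift equal to the mean drift inside the reference chunk.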
drift.locdata_corrected.meta
identifier: "26"
source: DESIGN
state: MODIFIED
history {
  name: "LocData.from_dataframe"
}
history {
  name: "add_drift"
  parameter: "{\'locdata\': <locan.data.locdata.LocData object at 0x7cdb297a8d10>, \'diffusion_constant\': None, \'velocity\': (0.002, 0.001), \'seed\': Generator(PCG64) at 0x7CDB337B35A0}"
}
history {
  name: "apply_correction"
  parameter: "{\'self\': Drift(chunks=None, chunk_size=10000, n_chunks=None, target=first, method=icp, kwargs_chunk=None, kwargs_register=None), \'locdata\': None, \'from_model\': True}"
}
ancestor_identifiers: "2"
ancestor_identifiers: "3"
element_count: 98201
frame_count: 31016
creation_time {
  seconds: 1777538002
  nanos: 672057000
}
modification_time {
  seconds: 1777538007
  nanos: 141515000
}

Drift analysis by a cross-correlation algorithm

The same kind of drift estimation and correction can be applied using the image cross-correlation algorithm.

%%time
drift = lc.Drift(chunk_size=10_000, target='first', method='cc').\
        compute(dat_blob_with_drift).\
        fit_transformations(slice_data=slice(None), offset_models=(LinearModel(), LinearModel()), verbose=True).\
        apply_correction(from_model=True);
/home/docs/checkouts/readthedocs.org/user_builds/locan/envs/latest/lib/python3.11/site-packages/locan/analysis/drift.py:265: UserWarning: The function register_cc has been refactored. The kwargs max_offset and verbose are deprecated . It now calls _register_cc_skimage. Use _register_cc_picasso for legacy behavior.
  transformation = register_cc(
CPU times: user 451 ms, sys: 9.02 ms, total: 461 ms
Wall time: 460 ms
../../_images/cb8b4edd1de17d048a4f472eb4fee173ef1059381fe6561cad17a6e13c2bd8cd.png ../../_images/e6c1f684e1401d1f6abcde899f33bf7008b5aa945c5f86b4742c4ba4eb1a2f7b.png
fig, axes = plt.subplots(1, 2, figsize=(15, 5))
lc.render_2d(drift.locdata_corrected, ax=axes[0], bin_size=2, rescale='equal', bin_range=((0, 200),(0, 200)));
lc.render_2d_mpl(drift.locdata_corrected, ax=axes[1], other_property='frame', bin_size=2, bin_range=((0, 200),(0, 200)), cmap='viridis');
../../_images/10d1da61a81b099cd2eedda7d95168f533a6cb96a940a95dab52a652a83507ac.png
rmse(dat_blob, drift.locdata_corrected)
array([3.22658567, 1.66731276])
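
The principle of the image cross-correlation approach used in this section can be sketched with NumPy: both point sets are binned into images, the circular cross-correlation is computed with the FFT, and the position of the peak gives the shift. This is a minimal illustration limited to whole-bin shifts; the function cc_shift_sketch is hypothetical and is not the routine that locan calls.

```python
import numpy as np


def cc_shift_sketch(points, reference, bin_size, bin_range):
    """Whole-bin shift that moves `points` onto `reference`."""
    edges = [np.arange(low, high + bin_size, bin_size) for low, high in bin_range]
    image, _, _ = np.histogram2d(points[:, 0], points[:, 1], bins=edges)
    ref_image, _, _ = np.histogram2d(reference[:, 0], reference[:, 1], bins=edges)
    # Circular cross-correlation via the correlation theorem.
    spectrum = np.fft.fft2(ref_image) * np.conj(np.fft.fft2(image))
    correlation = np.fft.ifft2(spectrum).real
    peak = np.array(np.unravel_index(np.argmax(correlation), correlation.shape))
    shape = np.array(correlation.shape)
    # Peak indices beyond half the image size correspond to negative shifts.
    peak = np.where(peak > shape // 2, peak - shape, peak)
    return peak * bin_size


rng = np.random.default_rng(0)
reference = rng.uniform(10, 90, size=(500, 2))
shifted = reference + np.array([6.0, -4.0])

shift = cc_shift_sketch(shifted, reference, bin_size=2, bin_range=((0, 100), (0, 100)))
print(shift)
```

The returned shift is the negative of the applied displacement, that is, the offset that registers the shifted points to the reference. Resolving drift below the bin size would require a sub-pixel refinement of the peak position, which this sketch omits.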