numpy – Applying Minimum Image Convention in Python

I am computing pairwise Euclidean distances between 3-D vectors representing particle positions in a periodic system. The minimum image convention is applied along each periodic dimension, so that a particle only considers the nearest image of another particle when computing the distance.
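For a single axis, the minimum image convention amounts to shifting the signed separation by a whole number of box lengths so that it falls within half a box length of zero. A minimal sketch, where `min_image_component` is a hypothetical helper name and the positions are made-up values:

```python
import numpy as np

def min_image_component(x1, x2, box_length):
    """Hypothetical helper: signed separation along one axis under the minimum image convention."""
    d = x1 - x2
    # Subtract the nearest whole number of box lengths so the result lies in [-L/2, L/2].
    return d - box_length * np.round(d / box_length)

# Two particles near opposite faces of a box of length 40:
raw = 39.0 - 1.0                             # 38.0 if periodicity is ignored
mic = min_image_component(39.0, 1.0, 40.0)  # -2.0: the image across the boundary is closer
print(raw, abs(mic))
```

Rounding the ratio d/L picks the nearest image, which is the same decision the np.where test against half_s makes for the unsigned distances in the code below.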

This code is part of a larger post-processing effort in Python. Here is a working example of what I am implementing:

import numpy as np
from scipy.spatial.distance import cdist

s = np.array((40, 30, 20))
half_s = 0.5*s

a = np.transpose(np.random.rand(1000,3) * s)
b = np.transpose(np.random.rand(1000,3) * s)

dists = np.empty((3, a.shape[1]*b.shape[1]))
for i in range(3):
    dists[i,:] = cdist(a[i,:].reshape(-1,1), b[i,:].reshape(-1,1), 'cityblock').ravel()
    dists[i,:] = np.where(dists[i,:] > half_s[i], dists[i,:] - s[i], dists[i,:])
dists = np.sqrt(np.einsum("ij,ij->j", dists, dists))
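Because the loop flattens an n × m grid of pairs into a single row, it helps to check the indexing against a brute-force calculation. A sketch, assuming small random inputs of deliberately different lengths (n = 7 and m = 5 are arbitrary choices); the vectorized part is the snippet above with square-bracket indexing:

```python
import numpy as np
from scipy.spatial.distance import cdist

rng = np.random.default_rng(0)
s = np.array((40.0, 30.0, 20.0))
half_s = 0.5 * s
n, m = 7, 5                                # different lengths on purpose
a = (rng.random((n, 3)) * s).T             # shape (3, n)
b = (rng.random((m, 3)) * s).T             # shape (3, m)

# Vectorized version from the question.
dists = np.empty((3, a.shape[1] * b.shape[1]))
for i in range(3):
    dists[i, :] = cdist(a[i, :].reshape(-1, 1), b[i, :].reshape(-1, 1), 'cityblock').ravel()
    dists[i, :] = np.where(dists[i, :] > half_s[i], dists[i, :] - s[i], dists[i, :])
dists = np.sqrt(np.einsum("ij,ij->j", dists, dists))

# Brute-force reference, one pair at a time.
ref = np.empty((n, m))
for p in range(n):
    for q in range(m):
        d = np.abs(a[:, p] - b[:, q])
        d = np.minimum(d, s - d)           # nearest periodic image, per component
        ref[p, q] = np.sqrt(np.sum(d * d))

print(np.allclose(dists.reshape(n, m), ref))
```

Since cdist returns an (n, m) array and ravel is row-major, dists.reshape(n, m)[p, q] is the distance between column p of a and column q of b.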

The domain size s and the 3×n particle position arrays a and b are obtained from existing data structures; this example uses the sizes I typically expect. Note that a and b can have different lengths, but on average I expect the final dists array to hold around a million distances (give or take an order of magnitude).

The last 5 lines, which compute the distances, will need to run many thousands of times, so they are what I hope to optimize.
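For comparison, the same computation can be written without the per-axis cdist loop by broadcasting signed differences and wrapping them with np.rint. This is a swapped-in technique, not a measured improvement: `min_image_dists_broadcast` is a hypothetical name, and the sketch materializes a (3, n, m) float array (about 24 MB for a million pairs) plus temporaries of the same size, so whether it beats cdist here is untested.

```python
import numpy as np

def min_image_dists_broadcast(a, b, s):
    """Hypothetical helper: minimum-image distances between columns of a (3, n) and b (3, m).

    Returns a flat array of length n*m in the same row-major order as the cdist loop.
    """
    s3 = np.asarray(s, dtype=float)[:, None, None]             # (3, 1, 1) for broadcasting
    d = np.asarray(a, dtype=float)[:, :, None] - np.asarray(b, dtype=float)[:, None, :]  # (3, n, m)
    d -= s3 * np.rint(d / s3)                                  # wrap each component into [-s/2, s/2]
    return np.sqrt(np.einsum("kij,kij->ij", d, d)).ravel()

rng = np.random.default_rng(1)
s = np.array((40.0, 30.0, 20.0))
a = (rng.random((50, 3)) * s).T                                # (3, 50)
b = (rng.random((40, 3)) * s).T                                # (3, 40)

# Cross-check with an independent formula: per-component distance min(|d|, s - |d|).
ad = np.abs(a[:, :, None] - b[:, None, :])
alt = np.minimum(ad, s[:, None, None] - ad)
print(np.allclose(min_image_dists_broadcast(a, b, s), np.sqrt((alt**2).sum(axis=0)).ravel()))
```

Working with signed differences removes the need for a separate sign-agnostic comparison against half_s, at the cost of holding all three components for every pair in memory at once.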

The difficulty arises from the need to apply the minimum image convention to each component independently. I haven't been able to find anything that beats SciPy's cdist for computing the unsigned distance components, and NumPy's einsum seems to be the most efficient way to reduce the squared components to distances for arrays of this size. At this point, the bottleneck in the code is the np.where call. I also tried a casting-based approach, replacing the penultimate line with

dists[i,:] -= (dists[i,:] * s_r[i] + 0.5).astype(int) * s[i]

where s_r = 1/s, but this yielded the same runtime. This paper discusses various techniques for handling this operation in C/C++, but I'm not familiar enough with the NumPy/SciPy internals to determine what's best here. I'd love to parallelize this section, but I had little success with the multiprocessing module, and cdist is incompatible with Numba. I would also entertain suggestions on writing a C extension, though I've never incorporated one into Python before.
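Both wrapping strategies from the question give the same squared components: for unsigned distances in [0, L), the np.where version and the casting version disagree only in sign when a component equals exactly half the box length. The sketch below checks this on synthetic data and adds a third, hypothetical variant that performs the casting trick in place through ufunc `out=` arguments to avoid allocating temporaries; it is untimed, so any speedup is an assumption to benchmark.

```python
import numpy as np

rng = np.random.default_rng(2)
L = 40.0
half_L = 0.5 * L
s_r = 1.0 / L
d = rng.random(1_000_000) * L            # synthetic unsigned component distances in [0, L)

# Variant 1: np.where, as in the question (allocates new arrays).
via_where = np.where(d > half_L, d - L, d)

# Variant 2: the casting trick; astype(int) truncates, which equals floor
# here because d * s_r + 0.5 is non-negative.
via_cast = d - (d * s_r + 0.5).astype(int) * L

# Variant 3 (hypothetical): the same arithmetic in place, reusing one buffer.
buf = np.empty_like(d)
np.multiply(d, s_r, out=buf)
np.add(buf, 0.5, out=buf)
np.floor(buf, out=buf)                   # 0 where d < L/2, 1 otherwise
np.multiply(buf, L, out=buf)
np.subtract(d, buf, out=buf)

# Signs may differ at d == L/2 exactly, so compare squares.
print(np.allclose(via_where**2, via_cast**2), np.allclose(via_where**2, buf**2))
```

In a repeated loop, buf could be allocated once and reused for every call, which is the only allocation this variant saves compared with the casting line in the question.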