Efficient sampling without replacement in Python
Sampling \(n\) elements without replacement from a collection of \(N\) elements means that no duplicates are allowed. A strategy for sampling without replacement is to sample with replacement, but reject already selected elements. This simple strategy is quite effective when we can expect few rejections, which is when i. \(n\) is much smaller than \(N\), and ii. the probability of choosing any particular element is small, i.e. no single element carries a large share of the total mass. However, when either i. or ii. fails to hold, we might be better off considering a different strategy.
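As a point of reference, here is a minimal sketch of the rejection strategy (my own illustration; the names and the use of random.choices are assumptions, not from the original):

from random import choices

def rejection_sample(items, weights, n):
    '''Sample n distinct elements: draw with replacement, proportionally
    to the weights, and reject already selected elements.'''
    chosen = set()
    while len(chosen) < n:
        x, = choices(items, weights=weights)  # one weighted draw, with replacement
        chosen.add(x)                         # a duplicate draw is a rejection
    return chosen

Each duplicate draw is wasted work, which is why the strategy degrades when \(n\) approaches \(N\) or when a few elements carry most of the mass.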
Let \(\Omega\) be a finite set of size \(N\) and let \(w : \Omega \to \mathbb{R}^+\) be a function that assigns a positive “mass” to each element of \(\Omega\). What we are now interested in is:
given a positive integer \(n \leq N\), select a set \(S \subseteq \Omega\) of size \(n\) such that the probability of \(x \in S\) is proportional to \(w(x)\).
In the general case, sampling \(S\) can be done efficiently by considering a tree of events (subsets of \(\Omega\)) where the children of each event represent a partition of their parent. To sample a single element, at each node in the tree, we choose a child with probability proportional to its mass (the sum of masses of elements in the event). When we arrive at a leaf, we choose an element from that leaf with probability proportional to its mass. Having chosen the element, we delete it from all events in the tree (which are exactly the events on the path we took down the tree).
In practice, we would only store a mass at each internal node in the tree, and upon choosing an element, we would subtract that mass from each internal node on the path that led to that particular element.
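To make the scheme concrete, here is a minimal array-based sketch (my own illustration; the heap-style node layout and all names are assumptions, not taken from Wong and Easton):

from random import random

def build_sumtree(weights):
    '''Heap-style partial sum tree: leaves hold the masses, each internal
    node holds the sum of its two children.'''
    N = len(weights)
    size = 1
    while size < N:
        size *= 2
    tree = [0.0] * (2 * size)
    tree[size:size + N] = weights
    for i in range(size - 1, 0, -1):
        tree[i] = tree[2 * i] + tree[2 * i + 1]
    return tree, size

def sample_and_delete(tree, size):
    '''Choose a leaf with probability proportional to its mass, then
    subtract that mass from every node on the root-to-leaf path.'''
    r = random() * tree[1]       # a uniform rank within the total mass
    i = 1
    while i < size:              # descend: go left if r falls in the left child
        if r < tree[2 * i]:
            i = 2 * i
        else:
            r -= tree[2 * i]
            i = 2 * i + 1
    mass = tree[i]
    leaf = i - size              # index of the chosen element
    while i >= 1:                # delete: update the path we took down
        tree[i] -= mass
        i //= 2
    return leaf

Sampling \(n\) elements without replacement is then just \(n\) calls to sample_and_delete.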
The above scheme is similar to Wong and Easton’s Partial Sum Tree algorithm, and costs \(O(N \log N)\) for creating the event tree, and each element can be chosen in \(O(\log N)\) time. The tree also takes \(O(N)\) space.
If we are clever and, instead of computing the tree explicitly, store only the nodes that have had their mass updated, we can get rid of the initial \(O(N \log N)\) time for constructing the tree; this requires that the mass of an untouched node can be computed on the fly, which is easy when, say, \(w\) is constant. What we store then also takes only \(O(n \log N)\) space. In other words, \(n\) elements are sampled by the above scheme in \(O(n \log N)\) time, requiring \(O(n \log N)\) space.
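In the uniform case, where the mass of an untouched node is simply the size of its index range, nothing needs to be precomputed at all. A minimal sketch of this lazy variant (my own illustration; the (lo, hi) node representation is an assumption):

from random import randrange

def lazy_uniform_sample(N, n):
    '''Walk an implicit binary tree over index ranges, storing counts only
    for nodes actually visited; setup is O(1), total space O(n log N).'''
    removed = {}          # (lo, hi) -> elements already chosen in [lo, hi)
    chosen = []
    for _ in range(n):
        lo, hi = 0, N
        # rank of the element to pick among those remaining at the root
        r = randrange((hi - lo) - removed.get((lo, hi), 0))
        while hi - lo > 1:
            removed[(lo, hi)] = removed.get((lo, hi), 0) + 1
            mid = (lo + hi) // 2
            left_mass = (mid - lo) - removed.get((lo, mid), 0)
            if r < left_mass:
                hi = mid          # descend left
            else:
                r -= left_mass
                lo = mid          # descend right
        removed[(lo, hi)] = 1     # lo is the chosen element
        chosen.append(lo)
    return chosen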
In the special case where \(w\) is constant, which corresponds to sampling according to a uniform distribution of mass, we can do better using the following simple algorithm, which is my slight extension of Ernvall and Nevalainen’s algorithm for unbiased random sampling (Ernvall and Nevalainen 1982):
from random import randrange

class ptdict(dict):
    '''dict that returns the key without storing it if key not found'''
    def __missing__(self, key):
        return key

def worsample(N, exclude=[]):
    '''sampling without replacement, returns a generator;
    do not sample from exclude'''
    end = N
    remap = ptdict()
    # exclude elements by remapping their positions to the current tail
    for x in sorted(exclude, reverse=True):
        remap[x] = remap[end - 1]
        end -= 1
    for i in range(end):
        # pick a random position, read off the element it holds, then
        # overwrite the position with the element at the current tail
        j = randrange(end)
        k = remap[j]
        remap[j] = remap[end - 1]
        end -= 1
        yield k
To use the code above to get a list l of \(n\) randomly chosen non-negative integers smaller than \(N\), we can do:
from itertools import islice
l = list(islice(worsample(N), n))
Note that the above algorithm works on \(\Omega = \{0, 1, \ldots, N - 1\}\). It also allows the specification of a requirement \(S \cap E = \emptyset\), where \(E \subseteq \Omega\) is given.
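For example, to draw three values from \(\{0, \ldots, 9\}\) while guaranteeing that 0, 3, and 7 (illustrative values) are never selected:

from itertools import islice
l = list(islice(worsample(10, exclude=[0, 3, 7]), 3))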
According to the Python wiki, Python dict inserts and lookups take average-case \(O(1)\) time. Using this, we get that the expected time complexity is \(O(m \log m + n)\), where \(m = |E|\). While this seems almost too good to be true, even a tree-based version of a dict would yield \(O(m \log m + n \log n)\), which is still good. The algorithm also uses \(n\) random numbers. The important thing to remember is that, assuming the encoding of numbers is fixed, the algorithm’s time and space usage are independent of \(N\).
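Since time and space are independent of \(N\), nothing stops us from sampling from an astronomically large range; a toy demonstration, assuming the Python 3 code above (where range is lazy):

from itertools import islice
# remap holds only a handful of entries; range(10**18) is never materialized
l = list(islice(worsample(10**18), 5))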