Friday, June 14, 2013

Filtering an iterator into N parts lazily

I read "Filter a list into two parts" by Ned Batchelder where he mentions this solution by Peter Otten for a split into two based on the result of a boolean function (or predicate):

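import itertools

def partition(items, predicate=bool):
    # Evaluate the predicate once per item, then tee the (flag, item) pairs.
    a, b = itertools.tee((predicate(item), item) for item in items)
    # First iterator: items where the predicate is false; second: where it is true.
    return ((item for pred, item in a if not pred),
            (item for pred, item in b if pred))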
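
As a quick check of the return order, here is a minimal usage sketch. The function is repeated so the snippet runs on its own, and the "is even" predicate and the range are example data of my own choosing, not something from Ned's post:

```python
import itertools

def partition(items, predicate=bool):
    # Same function as above, repeated so this snippet is self-contained.
    a, b = itertools.tee((predicate(item), item) for item in items)
    return ((item for pred, item in a if not pred),
            (item for pred, item in b if pred))

# Hypothetical example data: split 0..9 on "is even".
odds, evens = partition(range(10), lambda x: x % 2 == 0)

odds_list = list(odds)     # items where the predicate was false
evens_list = list(evens)   # items where the predicate was true
print(odds_list)   # [1, 3, 5, 7, 9]
print(evens_list)  # [0, 2, 4, 6, 8]
```

Note that the "false" items come back first, which is easy to get the wrong way round when unpacking.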

Now it works, but the use of a and b seemed clumsy. I decided to make the function more functional and generalize it to a split into more than two ways, as a means of getting the technique to stick.

The above function uses a boolean predicate, and its return statement tests '...a if not pred' and '...b if pred'. For a predicate that returns a bool, this is the same as testing pred==False and then pred==True. Because Python's bool is a subclass of int, with False equal to 0 and True equal to 1, we could just as well check pred==0 in the first case and pred==1 in the second.
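
The False/True and 0/1 equivalence can be checked directly. This is only a small sketch; is_even and buckets are names I made up for the illustration:

```python
# bool is a subclass of int, so False and True compare equal to 0 and 1.
assert issubclass(bool, int)
assert False == 0 and True == 1

def is_even(x):
    # Hypothetical predicate used only for this illustration.
    return x % 2 == 0

# Because of that equivalence, a boolean result can index a two-element tuple.
buckets = ([], [])
for x in range(5):
    buckets[is_even(x)].append(x)

print(buckets)  # ([1, 3], [0, 2, 4])
```

Seen this way, a boolean predicate is just the n=2 case of a function that returns a bucket number.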

So, to split an iterator into n parts we need to pass n as an argument and switch to a predicate function that returns an integer from 0 to n-1, which selects the returned iterator each item belongs to. itertools.tee takes an optional second integer argument that tees the input iterator into n independent iterators so, as I wrote in my comment on Ned's blog post, you can do the following:

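def partitionn(items, predicate=int, n=2):
    # Tag each item with its partition number once, then tee the tagged stream n ways.
    tees = itertools.tee( ((predicate(item), item)
                          for item in items), n )
    # The lambda fixes i for each partition at creation time.
    return ( (lambda i:(item for pred, item in tees[i] if pred==i))(x)
              for x in range(n) )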
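
The immediately called lambda in partitionn looks odd, so here is a sketch of what I believe it guards against. A plain nested generator expression looks up i only when it is consumed, so if all the partitions are collected before any is read, every one of them filters on the last value of i. partitionn_naive is a hypothetical variant written only to show the difference:

```python
import itertools

def partitionn_naive(items, predicate=int, n=2):
    # Hypothetical variant for illustration: no lambda, so i is looked up late.
    tees = itertools.tee(((predicate(item), item) for item in items), n)
    return ((item for pred, item in tees[i] if pred == i) for i in range(n))

def partitionn(items, predicate=int, n=2):
    # The version from the post: the lambda fixes i for each partition.
    tees = itertools.tee(((predicate(item), item) for item in items), n)
    return ((lambda i: (item for pred, item in tees[i] if pred == i))(x)
            for x in range(n))

# Collect all the partitions first, then read them.
naive = [list(p) for p in list(partitionn_naive(range(6), lambda x: x % 2, 2))]
fixed = [list(p) for p in list(partitionn(range(6), lambda x: x % 2, 2))]
print(naive)  # [[1, 3, 5], [1, 3, 5]]
print(fixed)  # [[0, 2, 4], [1, 3, 5]]
```

The naive version gives the right answer only when each partition is fully read before the next one is requested, which is too fragile for a function that hands back lazy iterators.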

I left it there on Ned's blog, but looking at it tonight I wanted to see if I could rid myself of the tees name. It is a tuple of n iterators, but maybe I could make the function even more functional?

A bit of refactoring left me with the solution below, which has no names outside the generator expression, runs the predicate once on each item, and works lazily on the input iterator as well as returning lazy iterators:

>>> def partitionn2(items, predicate=int, n=2):
    return ( (lambda i, tee:(item for pred, item in tee if pred==i))(x, t)
              for x, t in enumerate(itertools.tee( ((predicate(item), item)
                                                    for item in items), n )) )

>>> partitions = partitionn2(items=range(15), predicate=lambda x: x % 2, n=2)
>>> for p in partitions: print(p, list(p))

<generator object <genexpr> at 0x02D3C828> [0, 2, 4, 6, 8, 10, 12, 14]
<generator object <genexpr> at 0x02DCAEB8> [1, 3, 5, 7, 9, 11, 13]
>>> partitions = partitionn2(items=range(15), predicate=lambda x: x % 4, n=4)
>>> for p in partitions: print(p, list(p))

<generator object <genexpr> at 0x02DCAD28> [0, 4, 8, 12]
<generator object <genexpr> at 0x02DCAE90> [1, 5, 9, 13]
<generator object <genexpr> at 0x02DCAFA8> [2, 6, 10, 14]
<generator object <genexpr> at 0x02DCAFD0> [3, 7, 11]
>>>
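
The claims that the predicate runs once per item and that nothing is consumed up front can be checked with a predicate that records its calls. This is a sketch under my own assumptions: calls and counting_mod3 are names invented for the check, and partitionn2 is repeated so the snippet runs on its own:

```python
import itertools

def partitionn2(items, predicate=int, n=2):
    return ((lambda i, tee: (item for pred, item in tee if pred == i))(x, t)
            for x, t in enumerate(itertools.tee(
                ((predicate(item), item) for item in items), n)))

calls = []  # hypothetical log of every predicate call

def counting_mod3(x):
    # Hypothetical predicate: records its argument, then returns the bucket number.
    calls.append(x)
    return x % 3

parts = list(partitionn2(range(9), counting_mod3, 3))
calls_before = list(calls)       # nothing has been pulled from the input yet

first = next(parts[0])           # pulls a single item through the predicate
calls_after_one = list(calls)

rest = [list(p) for p in parts]  # drain whatever is left in every partition
print(calls_before, first, calls_after_one)  # [] 0 [0]
print(rest)                      # [[3, 6], [1, 4, 7], [2, 5, 8]]
print(calls == list(range(9)))   # True
```

One caveat: itertools.tee has to buffer every item that one partition has seen and another has not, so draining the partitions one after the other, as here, can hold most of the input in memory.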


Done: