イテレータをバッチ化する

[1, 2, 3, 4, 5, 6, 7, 8]

  ↓

[ [1, 2, 3], [4, 5, 6], [7, 8] ]

こんな形にするのを「バッチ化」と言うことがあります。これをイテレータを使って書く方法を考えてみます。

groupbyを使う方法

groupbyを使って、n個ずつグループ化して取り出す方法です。groupbyはキー関数が同じ値を返すものをグループ化するので、そのキー関数を「カウンタの値//バッチのサイズ」 (//は整数除算) とすればバッチ化になります。

    from itertools import count, groupby

    def batches(iterable, batchsize):
        iterable = iter(iterable)
        counter = count()
        for k,g in groupby(iterable, lambda x:counter.next()//batchsize):
            yield g

    for x in batches(xrange(1,32), 7):
        for i in x:
            print i,
        print

isliceを使う方法

これはActiveState Python Recipe
http://code.activestate.com/recipes/303279-getting-items-in-batches/
に書かれている方法で、isliceイテレータを使って「先頭からn個取り出す」を永久ループで繰り返す方法です。

    from itertools import islice, chain

    def batches(iterable, batchsize):
        iterable = iter(iterable)
        while True:
            batch = islice(iterable, batchsize)
            yield chain([batch.next()], batch)

    item_batch = batches(xrange(1,32), 7)
    for x in item_batch:
        for i in x:
            print i,
        print

シンプルですが1個所だけトリッキーな場所があります。本当はyield文で

        while True:
            batch = islice(iterable, batchsize)
            yield batch

のようにisliceイテレータをそのまま返せばいいはずです。でもこれは、途中まではうまく動くのですが、最後のループで空のisliceイテレータを返したとき、それ以降はiterableのイテレーションが一切発生しないため永久ループとなってしまいます。
そこで、isliceイテレータのnext()メソッドをわざと一回呼び出し、iterableのStopIteration例外を発生するようにしているのです。そして、一回next()で取り出した値をchainイテレータでくっつけなおして

    chain([0], iter([1, 2, 3, 4, 5, 6]))

という形で返しているのです。chainイテレータは全引数のイテレータから要素を取り出して順番に返すので、isliceイテレータを返すのと同様、0から6までのforループになります。

追記 2012-12-16

意外とシンプルなのに気づかなかった。仕様は少し違いますが、こちらの方が都合がいいかも。

>>> x = range(1, 32)
>>> map(None, *([iter(x)]*7))
[(1, 2, 3, 4, 5, 6, 7), (8, 9, 10, 11, 12, 13, 14), (15, 16, 17, 18,
 19, 20, 21), (22, 23, 24, 25, 26, 27, 28), (29, 30, 31, None, None,
 None, None)]