April 03, 2024

Filter batched zip(*product(…))

In the previous post I wrote about how you can iterate through zip(*product(*iterables)) in batches.

To refresh our memory, the function looks like this:

from itertools import product

def zip_product_in_batches(*iterables, size=1000):
    """
    Calculates the product of the provided 'iterables' and yields it 'size' elements at a time.
    """
    if type(size) != int or size < 1:
        raise ValueError(f"'size' param needs to be a positive integer. Got: {size} ({type(size)})")
    batch = []
    for elements in product(*iterables):
        batch.append(elements)
        if len(batch) >= size:
            yield zip(*batch)
            batch = []
    
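    # Yield the leftover elements; skip an empty batch, since an empty zip can't be unpacked.
    if batch:
        yield zip(*batch)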

But sometimes it’s useful to also filter the results and ignore the values we don’t need.

If we had a fixed number of iterables, we could simply add the elements to the batch only when the rule is true:

from itertools import product

def filtered_zip_product_in_batches(list_1, list_2, list_3, size=1000):
    """
    Calculates the product of the three lists, keeps only the combinations where
    element_1 < element_2 < element_3, and yields them 'size' elements at a time.
    """
    if type(size) != int or size < 1:
        raise ValueError(f"'size' param needs to be a positive integer. Got: {size} ({type(size)})")
    batch = []
    for element_1, element_2, element_3 in product(list_1, list_2, list_3):
        if not element_1 < element_2 < element_3:
            continue
        batch.append([element_1, element_2, element_3])
        if len(batch) >= size:
            yield zip(*batch)
            batch = []
    
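    # Yield the leftover elements; skip an empty batch, since an empty zip can't be unpacked.
    if batch:
        yield zip(*batch)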

But our original function accepts any number of iterables. To support that, we loop through the elements of each combination and compare the neighbours before we add the combination to the batch:

from itertools import product

def filtered_zip_product_in_batches(*iterables, size=1000):
    """
    Calculates the product of the provided 'iterables', keeps only the strictly increasing combinations and yields them 'size' elements at a time.
    """
    if type(size) != int or size < 1:
        raise ValueError(f"'size' param needs to be a positive integer. Got: {size} ({type(size)})")
    batch = []
    len_iterables = len(iterables)
    for elements in product(*iterables):
        if all(elements[i] < elements[i + 1] for i in range(len_iterables - 1)):
            batch.append(elements)
            if len(batch) >= size:
                yield zip(*batch)
                batch = []
    
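    # Yield the leftover elements; skip an empty batch, since an empty zip can't be unpacked.
    if batch:
        yield zip(*batch)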

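The all(...) condition walks through the elements of each combination and compares every element with the one that follows it. Each comparison is True if the first element is lower than the second, otherwise it’s False. The function all(iterable) returns True only if every element of the iterable is true, so a combination is kept only when its elements are strictly increasing.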
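
The same neighbour check can also be written without indexing. As a minimal sketch, the hypothetical helper is_strictly_increasing() below (it is not part of the function above) pairs each element with its successor using zip():

```python
def is_strictly_increasing(elements):
    """Hypothetical helper: True if every element is lower than the one after it."""
    # zip(elements, elements[1:]) yields the neighbouring pairs (e0, e1), (e1, e2), ...
    return all(a < b for a, b in zip(elements, elements[1:]))


print(is_strictly_increasing((1, 2, 3)))  # True
print(is_strictly_increasing((1, 3, 2)))  # False
print(is_strictly_increasing((2, 2, 3)))  # False, equal neighbours are rejected
```

Because product() yields tuples, the slice elements[1:] works here; for other kinds of input this approach may need adjusting.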

 

Let’s check it out in practice:

import numpy as np

range_1 = np.arange(1, 10, step=1, dtype=int)
range_2 = np.arange(1, 10, step=1, dtype=int)
range_3 = np.arange(1, 10, step=1, dtype=int)

for list_1, list_2, list_3 in filtered_zip_product_in_batches(range_1, range_2, range_3, size=10):
    print("list_1:", list_1)
    print("list_2:", list_2)
    print("list_3:", list_3)
    print("-" * 25)

Which will print:

list_1: (1, 1, 1, 1, 1, 1, 1, 1, 1, 1)
list_2: (2, 2, 2, 2, 2, 2, 2, 3, 3, 3)
list_3: (3, 4, 5, 6, 7, 8, 9, 4, 5, 6)
-------------------------
list_1: (1, 1, 1, 1, 1, 1, 1, 1, 1, 1)
list_2: (3, 3, 3, 4, 4, 4, 4, 4, 5, 5)
list_3: (7, 8, 9, 5, 6, 7, 8, 9, 6, 7)
-------------------------
list_1: (1, 1, 1, 1, 1, 1, 1, 1, 2, 2)
list_2: (5, 5, 6, 6, 6, 7, 7, 8, 3, 3)
list_3: (8, 9, 7, 8, 9, 8, 9, 9, 4, 5)
-------------------------
list_1: (2, 2, 2, 2, 2, 2, 2, 2, 2, 2)
list_2: (3, 3, 3, 3, 4, 4, 4, 4, 4, 5)
list_3: (6, 7, 8, 9, 5, 6, 7, 8, 9, 6)
-------------------------
list_1: (2, 2, 2, 2, 2, 2, 2, 2, 2, 3)
list_2: (5, 5, 5, 6, 6, 6, 7, 7, 8, 4)
list_3: (7, 8, 9, 7, 8, 9, 8, 9, 9, 5)
-------------------------
list_1: (3, 3, 3, 3, 3, 3, 3, 3, 3, 3)
list_2: (4, 4, 4, 4, 5, 5, 5, 5, 6, 6)
list_3: (6, 7, 8, 9, 6, 7, 8, 9, 7, 8)
-------------------------
list_1: (3, 3, 3, 3, 4, 4, 4, 4, 4, 4)
list_2: (6, 7, 7, 8, 5, 5, 5, 5, 6, 6)
list_3: (9, 8, 9, 9, 6, 7, 8, 9, 7, 8)
-------------------------
list_1: (4, 4, 4, 4, 5, 5, 5, 5, 5, 5)
list_2: (6, 7, 7, 8, 6, 6, 6, 7, 7, 8)
list_3: (9, 8, 9, 9, 7, 8, 9, 8, 9, 9)
-------------------------
list_1: (6, 6, 6, 7)
list_2: (7, 7, 8, 8)
list_3: (8, 9, 9, 9)
-------------------------

In this case, zip_product_in_batches() would return all 729 (9 × 9 × 9) combinations, while filtered_zip_product_in_batches() returns only 84 of them.
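
Both numbers can be double-checked with the standard library alone. The sketch below assumes the same inputs as in the example (three ranges holding 1 to 9) but uses plain range() instead of NumPy. Strictly increasing triples are exactly the 3-element combinations of the 9 values, so their count should match math.comb(9, 3) (available from Python 3.8):

```python
from itertools import product
from math import comb

values = range(1, 10)  # same values as np.arange(1, 10)

# Every triple of the unfiltered product: 9 * 9 * 9
total = sum(1 for _ in product(values, repeat=3))

# Only the triples where a < b < c
kept = sum(1 for a, b, c in product(values, repeat=3) if a < b < c)

print(total)       # 729
print(kept)        # 84
print(comb(9, 3))  # 84
```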

 
