Python asyncio concurrent map pool function

Here is a relatively short Python function to map input items concurrently, with a fixed pool size to limit the maximum concurrency. The work function to map each item is a normal synchronous function, which makes this easy to use in projects that are not otherwise using asyncio.

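This function can be called whether or not the calling thread already has an event loop. It uses an existing event loop if there is one, otherwise it creates a new event loop and uses that. Either way, the calling code does not need to know that this function is using asyncio internally. The one restriction is that it blocks while it drives the event loop, so it cannot be called directly from inside a coroutine that is already running on that loop; from async code, run it in an executor thread instead.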

import asyncio
from typing import Callable, Iterable, List, Tuple, TypeVar

T = TypeVar("T")
U = TypeVar("U")


def asyncio_concurrent_map(
    items: Iterable[T],
    work: Callable[[T], U],
    pool_size: int = 10,
) -> Iterable[U]:
    """Apply ``work`` to every item concurrently and yield the results in input order.

    ``work`` is an ordinary synchronous callable. Each call runs on a thread
    from a dedicated pool of ``pool_size`` threads, and an asyncio event loop
    collects the results. The calling thread's current event loop is used when
    there is a usable one; otherwise a temporary loop is created and closed
    again once the work is done.

    This is a generator function: nothing runs (not even the ``pool_size``
    check) until the returned iterable is first iterated, and that first step
    blocks until every item has been processed. Because it blocks while
    driving the event loop, it must not be called directly from a coroutine
    running on the same thread; run it in an executor thread instead.

    Args:
        items: The inputs to map. They are fully consumed before any result
            is yielded.
        work: Synchronous function applied to each item.
        pool_size: Maximum number of concurrent ``work`` calls, from 1 to 100.

    Yields:
        ``work(item)`` for each item, in the same order as ``items``.

    Raises:
        ValueError: If ``pool_size`` is outside of the range 1-100.
        Exception: The first exception raised by ``work`` or by iterating
            ``items`` is propagated; calls that have not started are cancelled.
    """
    if pool_size <= 0 or pool_size > 100:
        raise ValueError("pool_size %s outside of range 1-100" % pool_size)

    # Imported here so the function does not depend on edits to the module's
    # import block.
    from concurrent.futures import ThreadPoolExecutor

    async def process_items(executor: ThreadPoolExecutor) -> List[U]:
        running_loop = asyncio.get_running_loop()
        futures = []
        try:
            for item in items:
                # Passing ``item`` as an argument (rather than closing over it)
                # binds each call to its own item.
                futures.append(running_loop.run_in_executor(executor, work, item))
            if not futures:
                return []
            # gather() returns results in argument order, i.e. input order.
            return list(await asyncio.gather(*futures))
        except BaseException:
            # Do not leave work behind on a loop that may be reused: cancel
            # whatever has not started and collect the outcomes so nothing is
            # left pending or unretrieved.
            for future in futures:
                future.cancel()
            if futures:
                await asyncio.gather(*futures, return_exceptions=True)
            raise

    try:
        loop = asyncio.get_event_loop()
    except RuntimeError:
        # No current event loop in this thread.
        loop = None
    owns_loop = loop is None or loop.is_closed()
    if owns_loop:
        loop = asyncio.new_event_loop()

    try:
        # A dedicated executor makes ``pool_size`` the real concurrency limit
        # instead of whatever the loop's shared default executor allows.
        with ThreadPoolExecutor(max_workers=pool_size) as executor:
            results = loop.run_until_complete(process_items(executor))
    finally:
        if owns_loop:
            loop.close()

    yield from results

Here are some test cases with pytest and assertpy for the concurrent map pool function:

import asyncio
import string
import time
from time import sleep
from typing import Iterable

import pytest
from assertpy import assert_that


@pytest.mark.parametrize(
    "items, expected_mapped",
    [
        ([], []),
        ([1], [10]),
        ([100, 200], [1000, 2000]),
        (range(1000), [i * 10 for i in range(1000)]),
    ],
)
def test__asyncio_concurrent_map__given_int_mapper(
    items: Iterable[int],
    expected_mapped: Iterable[int],
):
    """Mapping ints through a times-ten function produces every expected product."""

    def times_ten(value: int) -> int:
        return value * 10

    # Results are compared in sorted order; this test does not check ordering.
    results = sorted(asyncio_concurrent_map(items, times_ten))
    assert_that(results).is_equal_to(expected_mapped)


@pytest.mark.parametrize(
    "items, expected_mapped",
    [
        ([], []),
        (["a"], ["A"]),
        (string.ascii_lowercase, [letter.upper() for letter in string.ascii_lowercase]),
    ],
)
def test__asyncio_concurrent_map__given_str_mapper(
    items: Iterable[str],
    expected_mapped: Iterable[str],
):
    """Mapping strings through an upper-casing function produces every expected string."""

    def to_upper(text: str) -> str:
        return text.upper()

    # Results are compared in sorted order; this test does not check ordering.
    results = sorted(asyncio_concurrent_map(items, to_upper))
    assert_that(results).is_equal_to(expected_mapped)


@pytest.mark.parametrize(
    "items, expected_mapped",
    [
        (
            [],
            [],
        ),
        (
            ["a", "b"],
            [{"foo": "a"}, {"foo": "b"}],
        ),
    ],
)
def test__asyncio_concurrent_map__given_str_dict_mapper(
    items: Iterable[str],
    expected_mapped: Iterable[dict],
):
    """Mapping strings to dicts produces one ``{"foo": item}`` dict per input item."""
    mapped = asyncio_concurrent_map(items, lambda s: {"foo": s})
    # Like the other mapper tests, compare in a canonical order so the
    # assertion does not depend on the order the results arrive in. Dicts are
    # not orderable, so sort on the mapped value.
    ordered = sorted(mapped, key=lambda result: result["foo"])
    assert_that(ordered).is_equal_to(expected_mapped)


@pytest.mark.parametrize(
    "pool_size",
    # range() excludes its stop value; 101 is needed to cover the largest
    # accepted pool size, 100.
    range(1, 101),
)
def test__asyncio_concurrent_map__given_pool_size(
    pool_size: int,
):
    """Every accepted pool size (1 through 100) maps all ten items correctly."""
    mapped = list(
        asyncio_concurrent_map(
            [0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
            lambda i: i * 2,
            pool_size,
        )
    )
    assert_that(sorted(mapped)).is_equal_to([0, 2, 4, 6, 8, 10, 12, 14, 16, 18])


@pytest.mark.parametrize(
    "pool_size",
    # Zero, a negative value, and the first value above the accepted maximum.
    [0, -1, 101],
)
def test__asyncio_concurrent_map__invalid_pool_size__raises(pool_size: int):
    """A pool size outside 1-100 is rejected with a ValueError naming the value."""
    with pytest.raises(ValueError, match=f"pool_size {pool_size} outside of range 1-100"):
        # list() is required: the function is a generator, so the check only
        # runs once iteration starts.
        list(
            asyncio_concurrent_map(
                [0, 1, 2, 3],
                lambda i: i,
                pool_size,
            )
        )


def test__asyncio_concurrent_map__concurrency_sanity():
    """
    Sanity check that work is really overlapped: eight sleeping tasks on a
    pool of four must finish in about two task durations.
    """
    single_task_time = 0.25
    margin_seconds = 0.1
    # 8 tasks on 4 workers is two rounds of sleeping, plus some slack.
    time_budget = (single_task_time * 2) + margin_seconds

    started_at = time.perf_counter()
    for _result in asyncio_concurrent_map(range(8), lambda _: sleep(single_task_time), 4):
        pass
    elapsed = time.perf_counter() - started_at

    assert_that(elapsed).is_less_than(time_budget)


def test__asyncio_concurrent_map__existing_event_loop():
    """
    Test that asyncio_concurrent_map can run while an event loop is already
    running in the process, by calling it from executor threads of that loop.
    """

    def blocking_map(map_start_i: int) -> list:
        # asyncio_concurrent_map is a generator, so it must be consumed here,
        # on the executor thread, for the map to actually run while the outer
        # event loop is running. Sorting keeps the assertion independent of
        # the order in which results arrive.
        return sorted(asyncio_concurrent_map(range(map_start_i, map_start_i + 3), lambda item: item * 2))

    async def single_map(map_start_i: int):
        return await asyncio.get_running_loop().run_in_executor(None, blocking_map, map_start_i)

    async def multi_map():
        return await asyncio.gather(*(single_map(start_i) for start_i in range(3)))

    # asyncio.run creates and cleans up its own loop, so the test does not
    # depend on a current event loop having been set up beforehand.
    multi_map_result = asyncio.run(multi_map())

    assert_that(multi_map_result).is_equal_to(
        [
            [0, 2, 4],
            [2, 4, 6],
            [4, 6, 8],
        ]
    )

By the way, you can hire me as a freelance Python developer to work on your project.


Tech mentioned