The Practice of Pool.imap in Python

DONG Yuxuan https://www.dyx.name

16 Sep 2020 (+0800)

Introduction

Given an iterable of numbers, we can compute the sum of the squares of its elements with the map and reduce functions.

from functools import reduce
import operator

def sqrsum(nums):
        return reduce(operator.add, map(lambda x: x**2, nums))

The benefit of using map and reduce instead of a for loop is that the program becomes easy to parallelize. To parallelize it in Python, we can replace map with multiprocessing.Pool.map.

from functools import reduce
import operator
import multiprocessing

def square(n):
        return n**2

def sqrsum2(nums):
        with multiprocessing.Pool() as pl:
                return reduce(operator.add, pl.map(square, nums))

multiprocessing.Pool represents a process pool. The default number of worker processes is the value returned by os.cpu_count. We use the pool in a with statement so that its resources are properly released.
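
If we want a different degree of parallelism, the number of workers can be passed explicitly. A minimal sketch (the count 4 here is an arbitrary choice):

import multiprocessing

def square(n):
        return n**2

# Request 4 workers instead of the os.cpu_count() default.
with multiprocessing.Pool(processes=4) as pl:
        print(pl.map(square, range(10)))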

Back in sqrsum2, after the pool pl is created, we call its map method, which makes the worker processes apply the function to the iterable in parallel and returns the mapped results.

The Pool.map method blocks until all of the work is finished, which may substantially hurt performance. The reduction needs only one element at a time, yet it must wait for all of them. Thus we care more about the non-blocking version of Pool.map: Pool.imap.

Pool.imap does the same job as Pool.map but returns immediately with an iterator. The iterator yields each element as soon as it is requested and ready. Let’s rewrite our snippet with Pool.imap.

from functools import reduce
import operator
import multiprocessing

def square(n):
        return n**2

def sqrsum3(nums):
        with multiprocessing.Pool() as pl:
                return reduce(operator.add, pl.imap(square, nums))

Exceptions

The power of Pool.imap lets us implement sqrsum3 concisely and efficiently. In many real-world cases, however, such a concise implementation is out of reach. One reason is that the mapper function may raise exceptions. Consider the following snippet.

import multiprocessing
import time

def bar(n):
        print('mapping', n)
        time.sleep(1)  # simulate some work
        if n == 5:
                raise ValueError()
        return n

def foo():
        with multiprocessing.Pool() as pl:
                outputs = pl.imap(bar, range(100))
                for i in outputs:
                        print('output', i)

foo()

The mapper raises an exception when n equals 5, so the loop in foo stops. However, some subsequent tasks may already have been submitted and executed. Here is the result on my machine.

mapping 0
mapping 1
mapping 2
mapping 3
mapping 4
mapping 5
mapping 6
mapping 7
mapping 8
mapping 9
mapping 10
mapping 11
mapping 12
output 0
output 1
mapping 13
output 2
output 3
mapping 14
output 4
mapping 15

The output stops at 4, but the tasks up to 15 have already been executed.

Thus, if we need to clean up resources, we must be careful.

For example, suppose we compute the sum of squares as follows: the mapper creates a temporary file, writes the square into it, and returns the path of the file; the reducer reads the squares from the temporary files and sums them up.

from functools import reduce
import multiprocessing
import tempfile
import os

def square(n):
        # mkstemp returns an open low-level descriptor; close it
        # since we reopen the file by path below.
        fd, path = tempfile.mkstemp()
        os.close(fd)

        with open(path, mode='w') as fp:
                fp.write(str(n**2))

        return path

def sumup(part, current):
        # Add the stored square to the partial sum, deleting the
        # file whether or not reading succeeds.
        try:
                with open(current) as fp:
                        return part + int(fp.read())
        finally:
                os.remove(current)

def sqrsum4(nums):
        with multiprocessing.Pool() as pl:
                return reduce(sumup, pl.imap(square, nums), 0)

print(sqrsum4(range(10)))

The sumup function adds the current square to the partial sum and deletes the temporary file. If the square function raises an exception, e.g. because it fails to create a temporary file, the reduction stops without cleaning up the files that the remaining tasks have already created.

A workaround is to drive the iterator manually with a while loop and the next function.

import multiprocessing
import tempfile
import os

def square(n):
        if n == 3:
                raise ValueError()

        fd, path = tempfile.mkstemp()
        os.close(fd)

        print(f'mapping {n} to {path}', flush=True)

        with open(path, mode='w') as fp:
                fp.write(str(n**2))

        return path

def sumup(part, current):
        print(f'sum up {current}', flush=True)

        with open(current) as fp:
                return part + int(fp.read())

def cleanup(filepath):
        print(f'clean up {filepath}', flush=True)

        os.remove(filepath)

def sqrsum5(nums):
        ret = 0
        err = None

        with multiprocessing.Pool() as pl:
                sqrs = pl.imap(square, nums)

                while True:
                        sqrfile = None

                        try:
                                sqrfile = next(sqrs)
                        except StopIteration:
                                break
                        except Exception as e:
                                # Remember the first error but keep
                                # draining the iterator so that every
                                # file that was produced gets removed.
                                if err is None:
                                        err = e

                        if sqrfile is not None:
                                try:
                                        # Once an error has occurred we
                                        # stop summing and only clean up.
                                        if not err:
                                                ret = sumup(ret, sqrfile)
                                finally:
                                        cleanup(sqrfile)

        if err:
                raise err

        return ret

print(sqrsum5(range(10)))

We move the cleanup code from the reducer into a separate cleanup function, and we replace the reduce function with a while loop so that we can handle exceptions and manage resources.

The obvious disadvantage of the workaround is that we must finish the whole mapping even if an exception happens at the very beginning. Thus my personal opinion on the best practice is to try our best not to leak resources from the mapper: clean up resources created by the mapper inside the mapper whenever possible.
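
For the temporary-file example, one way to follow this advice is to keep the file strictly local to the mapper, so that an aborted reduction cannot leak it. A minimal sketch of the idea:

import tempfile
import os

def square(n):
        fd, path = tempfile.mkstemp()
        os.close(fd)

        # The file never outlives this call: it is written, read
        # back, and removed here, and only the plain value is
        # returned, so nothing can leak however the caller aborts.
        try:
                with open(path, mode='w') as fp:
                        fp.write(str(n**2))
                with open(path) as fp:
                        return int(fp.read())
        finally:
                os.remove(path)

With such a mapper, the reducer is back to plain operator.add, and the concise shape of sqrsum3 works again.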

Signals

Exceptions break concise implementations because they abort the mapping early. Signals can do the same thing, and the situation is even worse: if a worker process is killed by a signal, the Pool may fall into a deadlock. Perhaps the issue will be fixed in a future version of Python, but it exists at least in Python 3.6.10 on macOS 10.15.5. The following snippet reproduces it.

import multiprocessing
import signal
import os

def foo(x):
        # The worker kills itself, so the pool never receives a
        # result for this task and the loop below hangs.
        os.kill(os.getpid(), signal.SIGTERM)
        return x

with multiprocessing.Pool() as pl:
        outputs = pl.imap(foo, range(100))
        for out in outputs:
                print(out)

To my knowledge, the issue has no generally reliable workaround. If anyone reading this text knows a solution and would like to leave a message, I will be very grateful.
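
A partial mitigation, though not a reliable workaround, is to bound how long we wait for each result: the iterator returned by Pool.imap has a next method accepting a timeout, which raises multiprocessing.TimeoutError instead of hanging forever. A sketch (the 60-second timeout is an arbitrary choice):

import multiprocessing

def foo(x):
        return x

with multiprocessing.Pool() as pl:
        outputs = pl.imap(foo, range(100))
        while True:
                try:
                        # Raises multiprocessing.TimeoutError if no
                        # result arrives within 60 seconds.
                        print(outputs.next(timeout=60))
                except StopIteration:
                        break

This turns a silent deadlock into an error we can at least observe, but choosing the timeout is inevitably ad hoc.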

Closures

Pool.imap needs to serialize the mapper function with pickle in order to send it to the worker processes, but a lambda function can’t be serialized by pickle, so we can’t use a lambda expression as the mapper.

""" Compute exponentials """
# This snippet doesn't work for it uses a lambda expression in `multiprocessing.Pool.imap`

import multiprocessing

with multiprocessing.Pool() as pl:
        base = 2
        outputs = pl.imap(lambda n: base**n, range(10))
        for out in outputs:
                print(out)

In fact, not only lambda expressions but also inner functions can’t be used: pickle simply can’t serialize closures. Any experienced developer will immediately see how severely this limits the application of Pool.imap. How can a functional programming tool not support closures?
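
We can verify the limitation directly with pickle:

import pickle

def outer():
        def inner(x):
                return x
        return inner

for func in (lambda x: x, outer()):
        try:
                pickle.dumps(func)
        except Exception as e:
                # e.g. "Can't pickle <lambda> ..." or "Can't pickle
                # local object 'outer.<locals>.inner'"
                print(type(e).__name__, e)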

A workaround is to simulate closures with partial functions. functools.partial binds arguments to a function and returns a bound callable, which can be serialized by pickle as long as the wrapped function is defined at module level. We can use it to rewrite the example above.

import multiprocessing
import functools

def exp(base, n):
        return base**n

with multiprocessing.Pool() as pl:
        base = 2
        bexp = functools.partial(exp, base)
        outputs = pl.imap(bexp, range(10))
        for out in outputs:
                print(out)

Of course, the workaround sacrifices concision even further.