Given an iterable whose elements are numbers, we can compute the sum of squares of its elements using the map and reduce functions.
from functools import reduce
import operator

def sqrsum(nums):
    return reduce(operator.add, map(lambda x: x**2, nums))
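For example, sqrsum(range(10)) returns 285, the sum of the squares of 0 through 9.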
The benefit of using map and reduce instead of a for loop is that the program becomes easy to parallelize. To parallelize it in Python, we can replace map with multiprocessing.Pool.map.
from functools import reduce
import operator
import multiprocessing

def square(n):
    return n**2

def sqrsum2(nums):
    with multiprocessing.Pool() as pl:
        return reduce(operator.add, pl.map(square, nums))
multiprocessing.Pool represents a process pool. The default number of worker processes is given by the os.cpu_count function. We use the pool in a with statement for proper management of resources. After the pool pl is created, we call its map method, which has the worker processes apply the function to the iterable in parallel and returns a list of the results.
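As a quick sketch of the pool-size knob (the value 2 below is arbitrary), we can print the default and override it by passing processes explicitly:

import multiprocessing
import os

# The default pool size equals os.cpu_count(); here we request
# exactly two worker processes instead.
print(os.cpu_count())
with multiprocessing.Pool(processes=2) as pl:
    print(pl.map(abs, [-1, -2, -3]))  # prints [1, 2, 3]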
The Pool.map method blocks until all of the work is finished, which may substantially hurt performance: the reduction only needs one element at a time, yet it has to wait for all of them. Thus we care more about the non-blocking version of Pool.map: Pool.imap.
Pool.imap does the same job as Pool.map but returns an iterator immediately; the iterator yields each element once it is requested and ready. Let's rewrite our snippet with Pool.imap.
from functools import reduce
import operator
import multiprocessing

def square(n):
    return n**2

def sqrsum3(nums):
    with multiprocessing.Pool() as pl:
        return reduce(operator.add, pl.imap(square, nums))
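As a side note on tuning, Pool.imap also accepts an optional chunksize argument that batches work items to cut inter-process overhead when the per-item work is cheap. A small sketch, where sqrsum3_chunked is a hypothetical variant and the value 64 is purely illustrative:

from functools import reduce
import operator
import multiprocessing

def square(n):
    return n**2

def sqrsum3_chunked(nums):
    with multiprocessing.Pool() as pl:
        # Larger chunks mean fewer round trips between processes,
        # at the cost of coarser-grained laziness.
        return reduce(operator.add, pl.imap(square, nums, chunksize=64))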
Thanks to the power of Pool.imap, sqrsum3 is implemented in a very concise and efficient way. However, in many real-world cases we can't achieve such a concise implementation. One reason is that the mapper function may raise exceptions. Consider the following snippet.
import multiprocessing
import time

def bar(n):
    print('mapping', n)
    time.sleep(1)
    if n == 5:
        raise ValueError()
    return n

def foo():
    with multiprocessing.Pool() as pl:
        outputs = pl.imap(bar, range(100))
        for i in outputs:
            print('output', i)

foo()
The mapper raises an exception when n equals 5, and the loop in foo stops. However, some subsequent tasks may already have been submitted and executed. Here is the result on my machine.
mapping 0
mapping 1
mapping 2
mapping 3
mapping 4
mapping 5
mapping 6
mapping 7
mapping 8
mapping 9
mapping 10
mapping 11
mapping 12
output 0
output 1
mapping 13
output 2
output 3
mapping 14
output 4
mapping 15
The output stops at 4, but the tasks up to 15 have already been executed. Thus, if we need to clean things up, we must be careful.
For example, suppose we compute the sum of squares again, but this time the mapper creates a temporary file, writes the square to it, and returns the path of the file, while the reducer reads the squares from the temporary files and sums them up.
from functools import reduce
import multiprocessing
import tempfile
import os

def square(n):
    fd, path = tempfile.mkstemp()
    os.close(fd)
    with open(path, mode='w') as fp:
        fp.write(str(n**2))
    return path

def sumup(part, current):
    try:
        with open(current) as fp:
            return part + int(fp.read())
    finally:
        os.remove(current)

def sqrsum4(nums):
    with multiprocessing.Pool() as pl:
        return reduce(sumup, pl.imap(square, nums), 0)

print(sqrsum4(range(10)))
The sumup function adds the current square to the partial sum and deletes the temporary file. If an exception is raised by the square function, e.g. it fails to create a temporary file, the reduction stops without cleaning up the files produced by subsequent tasks. A workaround is to use a while loop with the next function.
import multiprocessing
import tempfile
import os

def square(n):
    if n == 3:
        raise ValueError()
    fd, path = tempfile.mkstemp()
    os.close(fd)
    print(f'mapping {n} to {path}', flush=True)
    with open(path, mode='w') as fp:
        fp.write(str(n**2))
    return path

def sumup(part, current):
    print(f'sum up {current}', flush=True)
    with open(current) as fp:
        return part + int(fp.read())

def cleanup(filepath):
    print(f'clean up {filepath}', flush=True)
    os.remove(filepath)

def sqrsum5(nums):
    ret = 0
    err = None
    with multiprocessing.Pool() as pl:
        sqrs = pl.imap(square, nums)
        while True:
            sqrfile = None
            try:
                sqrfile = next(sqrs)
            except StopIteration:
                break
            except Exception as e:
                # Remember the first exception but keep draining the
                # iterator so every produced file gets cleaned up.
                if err is None:
                    err = e
            if sqrfile is not None:
                try:
                    if err is None:
                        ret = sumup(ret, sqrfile)
                finally:
                    cleanup(sqrfile)
    if err is not None:
        raise err
    return ret

print(sqrsum5(range(10)))
We move the cleanup code from the reducer into an independent function, cleanup, and replace the reduce function with a while loop so that we can handle exceptions and manage resources properly.
The obvious disadvantage of this workaround is that we must finish the whole mapping even if an exception happened at the very beginning. Thus, in my opinion, the best practice is to try our best not to leak resources from the mapper: whenever possible, clean up the resources a mapper creates inside the mapper itself.
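As a minimal sketch of that practice (square_safe is a hypothetical variant of our mapper), the function below deletes its own temporary file whenever writing fails, so nothing leaks even if the reducer never sees the path:

import os
import tempfile

def square_safe(n):
    fd, path = tempfile.mkstemp()
    os.close(fd)
    try:
        with open(path, mode='w') as fp:
            fp.write(str(n**2))
    except BaseException:
        # The reducer will never receive this path, so remove the
        # file here instead of leaking it.
        os.remove(path)
        raise
    return path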
Exceptions break concise implementations because they abort the mapping early. Signals can do the same thing, and they are even worse: if a worker process is killed by a signal, the Pool may fall into a deadlock. The issue may be fixed in a future version of Python, but it exists at least in Python 3.6.10 on macOS 10.15.5. The following snippet reproduces it.
import multiprocessing
import signal
import os

def foo(x):
    # Each worker process kills itself with SIGTERM.
    os.kill(os.getpid(), signal.SIGTERM)
    return x

with multiprocessing.Pool() as pl:
    outputs = pl.imap(foo, range(100))
    for out in outputs:
        print(out)
To my knowledge, there is no generally reliable workaround for this issue. If anyone reading this knows a solution and would like to leave a message, I will be very grateful.
For some subtle reasons, Pool.imap needs to use pickle to serialize the mapper function, but a lambda function can't be serialized by pickle, so we can't use a lambda expression as the mapper function.
""" Compute exponentials """
# This snippet doesn't work for it uses a lambda expression in `multiprocessing.Pool.imap`
import multiprocessing
with multiprocessing.Pool() as pl:
base = 2
outputs = pl.imap(lambda n: base**n, range(10))
for out in outputs:
print(out)
In fact, not only lambda expressions but also inner functions can't be used: pickle simply can't serialize closures. Any experienced developer will immediately realize how severely this limits the applicability of Pool.imap. How can a functional programming tool not support closures?
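A quick way to see the limitation for ourselves (a small sketch, independent of multiprocessing) is to hand pickle a lambda and an inner function directly; both are rejected, though the exact error type and message vary across Python versions:

import pickle

def outer():
    y = 1
    def inner(x):
        return x + y
    return inner

for fn in (lambda x: x, outer()):
    try:
        pickle.dumps(fn)
    except Exception as e:
        # Depending on the Python version this is a PicklingError or
        # an AttributeError; either way the function isn't serialized.
        print(type(e).__name__, ':', e)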
A workaround is to simulate closures with partial functions. functools.partial binds arguments to a function and returns a bound callable, which can be serialized by pickle. We can use it to rewrite the above example.
import multiprocessing
import functools

def exp(base, n):
    return base**n

with multiprocessing.Pool() as pl:
    base = 2
    bexp = functools.partial(exp, base)
    outputs = pl.imap(bexp, range(10))
    for out in outputs:
        print(out)
Of course, the workaround further erodes the concision.