Multi-process task queue using Redis Streams
In this post I'll present a short code snippet demonstrating how to use Redis streams to implement a multi-process task queue with Python. Task queues are commonly used in web-based applications, as they allow decoupling time-consuming computation from the request/response cycle. For example, when someone submits the "contact me" form, the webapp puts a message onto a task queue, so that the relatively time-consuming process of checking for spam and sending an email occurs outside the web request in a separate worker process.
The script I'll present is around 100 actual lines of code (excluding comments and whitespace), and provides a familiar API:
queue = TaskQueue('my-queue')

@queue.task
def fib(n):
    a, b = 0, 1
    for _ in range(n):
        a, b = b, a + b
    return b

# Calculate 100,000th fibonacci number in worker process.
fib100k = fib(100000)

# Block until the result becomes ready, then display last 6 digits.
print('100,000th fibonacci ends with: %s' % str(fib100k())[-6:])
When using Redis as a message broker, I've always favored using LPUSH/BRPOP (left-push, blocking right-pop) to enqueue and dequeue a message. Pushing items onto a list ensures that messages will not be lost if the queue is growing faster than it can be processed – messages just get added until the consumer(s) catch up. Blocking right-pop is an atomic operation, so Redis also guarantees that no matter how many consumers you've got listening for messages, each message is delivered to only one consumer.
There are some downsides to using lists, primarily the fact that blocking right-pop is a destructive read. Once a message is read, the application can no longer tell whether the message was processed successfully or has failed and needs to be retried. Similarly, there is no visibility into which consumer processed a given message.
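As a rough illustration of the list-based pattern described above, here is a minimal sketch. It assumes `client` is a redis-py-style client exposing `lpush()` and `brpop()`; the helper names `enqueue` and `dequeue`, and the use of pickle for the payload, are choices made for this example only.

```python
import pickle


def enqueue(client, key, payload):
    # LPUSH adds the serialized message to the head (left) of the list.
    client.lpush(key, pickle.dumps(payload))


def dequeue(client, key, timeout=1):
    # BRPOP blocks for up to `timeout` seconds, then atomically removes and
    # returns one item from the tail (right) of the list.
    item = client.brpop(key, timeout=timeout)
    if item is None:
        return None  # Timed out, nothing was waiting.
    _key, data = item  # BRPOP replies with a (list key, value) pair.
    return pickle.loads(data)
```

Once `dequeue()` returns, the message is gone from the list. If the worker crashes before finishing, nothing in Redis records that the message was ever handed out, which is the destructive-read problem mentioned above.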
Redis 5.0 includes a new streams data-type for modelling append-only, persistent message logging. Streams are identified by a key, like other data-types, and support append, read and delete operations. Streams provide a number of benefits over other data-types typically used for building distributed task queues using Redis, particularly when used with consumer groups.
- Streams support fan-out message delivery to all interested readers (kinda like pub/sub), or you can use consumer groups to ensure that messages are distributed evenly among a pool of consumers (like lpush/brpop).
- Messages are persistent and history is kept around, even after a message has been read by a consumer.
- Message delivery information is tracked by Redis, making it easy to identify which tasks were completed successfully, and which failed and need to be retried (at the cost of an explicit ACK).
- Messages are structured as any number of arbitrary key/value pairs, providing a bit more internal structure than an opaque blob stored in a list.
Consumer groups provide us with a unified interface for managing message delivery and querying the status of the task queue. These features make Redis a nice option if you need a message broker.
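The consumer-group flow can be sketched with raw stream commands. This example assumes a redis-py-style client that exposes `xgroup_create()`, `xadd()`, `xreadgroup()` and `xack()` (newer redis-py releases provide these); the function names and the `handler` callback are hypothetical and are not part of the script presented in this post.

```python
def ensure_group(client, stream, group):
    # Create the consumer group, creating the stream too if it is missing.
    # Redis replies with a BUSYGROUP error when the group already exists.
    try:
        client.xgroup_create(stream, group, id='0', mkstream=True)
    except Exception as exc:
        if 'BUSYGROUP' not in str(exc):
            raise


def produce(client, stream, fields):
    # XADD appends the message and returns the ID Redis generated for it.
    return client.xadd(stream, fields)


def consume_one(client, stream, group, consumer, handler, block_ms=1000):
    # '>' asks for messages never delivered to any consumer in this group.
    resp = client.xreadgroup(group, consumer, {stream: '>'},
                             count=1, block=block_ms)
    if not resp:
        return None  # Timed out without receiving a message.
    for _stream, messages in resp:
        for message_id, data in messages:
            handler(data)
            # Acknowledge only after the handler returned without raising;
            # otherwise the message stays in the group's pending list.
            client.xack(stream, group, message_id)
            return message_id
    return None
```

If `handler` raises, the message is never acknowledged, so it remains visible as pending and can be inspected or retried later.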
Building a simple task queue
To show how easy it is to build a task queue on top of Redis streams, we'll implement a simple multi-process task queue using Python. Our task queue will support execution of arbitrary Python functions among a pool of worker processes using the familiar @task decorator.
The code revolves around three simple classes:
- TaskQueue – writes messages to the stream and coordinates a pool of worker processes. Also provides interfaces for reading and writing task results.
- TaskWorker – reads messages from the stream and executes the tasks.
- TaskResultWrapper – allows the caller of a task-decorated function to retrieve the result of the task.
If you look closely at the code, you'll see it is all a thin Python layer on top of just a handful of Redis calls:
- The add() method is used to write a message to the stream when a task is invoked.
- The read() method of the consumer group is used to read one message at a time from the stream.
- The ack() method is used by the worker to indicate the successful execution of a task.
- hget and hset are used to store task return values in a Redis hash.
Here's the code, which is thoroughly commented and sits at around 100 source lines of code:
from collections import namedtuple
from functools import wraps
import datetime
import multiprocessing
import pickle
import time
# At the time of writing, the standard redis-py client does not implement
# stream/consumer-group commands. We'll use "walrus", which extends the client
# from redis-py to provide stream support and high-level, Pythonic containers.
# More info: https://github.com/coleifer/walrus
from walrus import Walrus
# Lightweight wrapper for storing exceptions that occurred executing a task.
TaskError = namedtuple('TaskError', ('error',))
class TaskQueue(object):
    def __init__(self, client, stream_key='tasks'):
        self.client = client  # our Redis client.
        self.stream_key = stream_key

        # We'll also create a consumer group (whose name is derived from the
        # stream key). Consumer groups are needed to provide message delivery
        # tracking and to ensure that our messages are distributed among the
        # worker processes.
        self.name = stream_key + '-cg'
        self.consumer_group = self.client.consumer_group(self.name, stream_key)
        self.result_key = stream_key + '.results'  # Store results in a Hash.

        # Obtain a reference to the stream within the context of the
        # consumer group.
        self.stream = getattr(self.consumer_group, stream_key)
        self.signal = multiprocessing.Event()  # Used to signal shutdown.
        self.signal.set()  # Indicate the server is not running.

        # Create the stream and consumer group (if they do not exist).
        self.consumer_group.create()
        self._running = False
        self._tasks = {}  # Lookup table for mapping function name -> impl.

    def task(self, fn):
        self._tasks[fn.__name__] = fn  # Store function in lookup table.

        @wraps(fn)
        def inner(*args, **kwargs):
            # When the decorated function is called, a message is added to the
            # stream and a wrapper class is returned, which provides access to
            # the task result.
            message = self.serialize_message(fn, args, kwargs)

            # Our message format is very simple -- just a "task" key and a blob
            # of pickled data. You could extend this to provide additional
            # data, such as the source of the event, etc, etc.
            task_id = self.stream.add({'task': message})
            return TaskResultWrapper(self, task_id)
        return inner

    def deserialize_message(self, message):
        task_name, args, kwargs = pickle.loads(message)
        if task_name not in self._tasks:
            raise Exception('task "%s" not registered with queue.' % task_name)
        return self._tasks[task_name], args, kwargs

    def serialize_message(self, task, args=None, kwargs=None):
        return pickle.dumps((task.__name__, args, kwargs))

    def store_result(self, task_id, result):
        # API for storing the return value from a task. This is called by the
        # workers after the execution of a task.
        if result is not None:
            self.client.hset(self.result_key, task_id, pickle.dumps(result))

    def get_result(self, task_id):
        # Obtain the return value of a finished task. This API is used by the
        # TaskResultWrapper class. We'll use a pipeline to ensure that reading
        # and popping the result is an atomic operation.
        pipe = self.client.pipeline()
        pipe.hexists(self.result_key, task_id)
        pipe.hget(self.result_key, task_id)
        pipe.hdel(self.result_key, task_id)
        exists, val, n = pipe.execute()
        return pickle.loads(val) if exists else None

    def run(self, nworkers=1):
        if not self.signal.is_set():
            raise Exception('workers are already running')

        # Start a pool of worker processes.
        self._pool = []
        self.signal.clear()
        for i in range(nworkers):
            worker = TaskWorker(self)
            worker_t = multiprocessing.Process(target=worker.run)
            worker_t.start()
            self._pool.append(worker_t)

    def shutdown(self):
        if self.signal.is_set():
            raise Exception('workers are not running')

        # Send the "shutdown" signal and wait for the worker processes
        # to exit.
        self.signal.set()
        for worker_t in self._pool:
            worker_t.join()

class TaskWorker(object):
    _worker_idx = 0

    def __init__(self, queue):
        self.queue = queue
        self.consumer_group = queue.consumer_group

        # Assign each worker process a unique name.
        TaskWorker._worker_idx += 1
        worker_name = 'worker-%s' % TaskWorker._worker_idx
        self.worker_name = worker_name

    def run(self):
        while not self.queue.signal.is_set():
            # Read up to one message, blocking for up to 1sec, and identifying
            # ourselves using our "worker name".
            resp = self.consumer_group.read(1, 1000, self.worker_name)
            if resp is not None:
                # Resp is structured as:
                # {stream_key: [(message id, data), ...]}
                for stream_key, message_list in resp:
                    task_id, data = message_list[0]
                    self.execute(task_id.decode('utf-8'), data[b'task'])

    def execute(self, task_id, message):
        # Deserialize the task message, which consists of the task name, args
        # and kwargs. The task function is then looked-up by name and called
        # using the given arguments.
        task, args, kwargs = self.queue.deserialize_message(message)
        try:
            ret = task(*(args or ()), **(kwargs or {}))
        except Exception as exc:
            # On failure, we'll store a special "TaskError" as the result. This
            # will signal to the user that the task failed with an exception.
            self.queue.store_result(task_id, TaskError(str(exc)))
        else:
            # Store the result and acknowledge (ACK) the message.
            self.queue.store_result(task_id, ret)
            self.queue.stream.ack(task_id)

class TaskResultWrapper(object):
    def __init__(self, queue, task_id):
        self.queue = queue
        self.task_id = task_id
        self._result = None

    def __call__(self, block=True, timeout=None):
        if self._result is None:
            # Get the result from the result-store, optionally blocking until
            # the result becomes available.
            if not block:
                result = self.queue.get_result(self.task_id)
            else:
                start = time.time()
                while timeout is None or (start + timeout) > time.time():
                    result = self.queue.get_result(self.task_id)
                    if result is None:
                        time.sleep(0.1)
                    else:
                        break

            if result is not None:
                self._result = result

        if self._result is not None and isinstance(self._result, TaskError):
            raise Exception('task failed: %s' % self._result.error)
        return self._result
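To see the message format in isolation, here is a minimal sketch that mirrors what serialize_message() and deserialize_message() do, without touching Redis. The names `registry`, `register`, `serialize`, `run_message` and `add` are hypothetical and exist only for this example.

```python
import pickle

registry = {}  # Maps function name -> implementation, like TaskQueue._tasks.


def register(fn):
    registry[fn.__name__] = fn
    return fn


def serialize(fn, args=(), kwargs=None):
    # Only the function *name* travels in the message, never the code itself.
    return pickle.dumps((fn.__name__, args, kwargs or {}))


def run_message(message):
    # Look the function up by name and call it with the unpickled arguments.
    name, args, kwargs = pickle.loads(message)
    if name not in registry:
        raise KeyError('task "%s" not registered' % name)
    return registry[name](*args, **kwargs)


@register
def add(a, b=0):
    return a + b


message = serialize(add, (1,), {'b': 2})
print(run_message(message))  # prints 3
```

Because only the name is sent, the worker must have the same function registered under that name. Since unpickling can execute arbitrary code, the stream should only carry messages from trusted producers.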
Let's look at a sample script utilizing this code. We'll define two tasks -- one that just sleeps for a given number of seconds, and another that is CPU-bound and computes the nth fibonacci number.
db = Walrus()  # roughly equivalent to db = Redis().
queue = TaskQueue(db)

@queue.task
def sleep(n):
    print('going to sleep for %s seconds' % n)
    time.sleep(n)
    print('woke up after %s seconds' % n)

@queue.task
def fib(n):
    a, b = 0, 1
    for _ in range(n):
        a, b = b, a + b
    return b

# Start the queue with four worker processes.
queue.run(4)

# Send four "sleep" tasks.
sleep(2)
sleep(3)
sleep(4)
sleep(5)

# Send four tasks to compute large fibonacci numbers. We will then print the
# last 6 digits of each computed number (showing how result storage works):
v100k = fib(100000)
v200k = fib(200000)
v300k = fib(300000)
v400k = fib(400000)

# Calling the result wrapper will block until its value becomes available:
print('100kth fib number ends with: %s' % str(v100k())[-6:])
print('200kth fib number ends with: %s' % str(v200k())[-6:])
print('300kth fib number ends with: %s' % str(v300k())[-6:])
print('400kth fib number ends with: %s' % str(v400k())[-6:])

# We can shutdown and restart the consumer.
queue.shutdown()
print('all workers have stopped.')

queue.run(4)
print('workers are running again.')

# Enqueue another "sleep" task.
sleep(2)

# Calling shutdown now will block until the above sleep task
# has finished, after which all workers will stop.
queue.shutdown()
print('done!')

On my machine, this short script produces the following output:

going to sleep for 3 seconds
going to sleep for 2 seconds
going to sleep for 4 seconds
going to sleep for 5 seconds
woke up after 2 seconds
100kth fib number ends with: 537501
200kth fib number ends with: 590626
woke up after 3 seconds
300kth fib number ends with: 800001
woke up after 4 seconds
woke up after 5 seconds
400kth fib number ends with: 337501
all workers have stopped.
workers are running again.
going to sleep for 2 seconds
woke up after 2 seconds
done!
By observing the script run, it's easy to see that the tasks are being executed in parallel, as the entire script takes just 8 seconds to finish!
Enhancing this design
By virtue of simply using Redis streams for the message broker, this simple example benefits from the fact that message history is preserved and can be introspected at any time. A number of useful features provided by streams are not being utilized, however:
- Pending (read, but unacknowledged) messages can be introspected and claimed by another consumer. This could be used to retry tasks that have failed, or to track long-running tasks. See the documentation for the pending() method.
- Messages can contain arbitrary key/value data, allowing each task message to contain additional application-specific metadata.
- Consumer groups can be configured to read from multiple streams, for example a "high-priority" stream that is low-traffic, and a "general purpose" stream for everything else.
- When adding messages to a stream, you can specify an approximate upper-bound on the length of the stream, allowing you to keep the stream from growing endlessly. To do this, specify a maxlen in the call to add().
- Because streams are persistent, it should be easy to use the range() and pending() functionality to implement a dashboard for viewing the status of the queue.
- How would you handle tasks that should be executed in the future? Hint: you can use Redis sorted-sets and specify the task's execution timestamp as the "score" (this requires setting up a separate process to manage reading from the schedule).
- Allow the task queue and worker management to be run standalone.
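Two of the ideas in this list, capping the stream length and reclaiming stale pending messages, might look roughly like the sketch below. It assumes a redis-py-style client exposing `xadd()`, `xpending_range()` and `xclaim()` rather than the walrus containers used in the script; the helper names `add_capped` and `reclaim_stale`, and the default thresholds, are hypothetical.

```python
def add_capped(client, stream, fields, maxlen=10000):
    # approximate=True lets Redis trim lazily ("MAXLEN ~"), which is cheaper
    # than enforcing the exact length on every write.
    return client.xadd(stream, fields, maxlen=maxlen, approximate=True)


def reclaim_stale(client, stream, group, consumer, min_idle_ms=60000, count=10):
    # The range form of XPENDING lists messages that were read by some
    # consumer in the group but never acknowledged.
    pending = client.xpending_range(stream, group, '-', '+', count)
    stale_ids = [
        entry['message_id'] for entry in pending
        if entry['time_since_delivered'] >= min_idle_ms
    ]
    if not stale_ids:
        return []
    # XCLAIM transfers ownership to `consumer`. Redis re-checks the idle
    # time, so a message that was just claimed elsewhere is skipped.
    return client.xclaim(stream, group, consumer, min_idle_ms, stale_ids)
```

A worker could call `reclaim_stale()` periodically and re-execute whatever it returns, acknowledging each message once the retry succeeds.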
What functionality would you add?
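For the scheduling hint above, one possible shape is sketched here. It assumes a redis-py-style client exposing `zadd()` (taking a member-to-score mapping), `zrangebyscore()` and `zrem()`; the helper names `schedule` and `pop_due` are hypothetical, and a separate scheduler process would still be needed to call `pop_due()` periodically and add the returned tasks to the stream.

```python
import pickle
import time


def schedule(client, key, payload, run_at):
    # The score is the UNIX timestamp at which the task becomes due.
    client.zadd(key, {pickle.dumps(payload): run_at})


def pop_due(client, key, now=None):
    now = time.time() if now is None else now
    due = []
    # Members with a score <= now are ready to run, earliest first.
    for member in client.zrangebyscore(key, '-inf', now):
        # ZREM returns how many members it removed. If another scheduler
        # process removed this one first, we get 0 and skip it.
        if client.zrem(key, member):
            due.append(pickle.loads(member))
    return due
```

Checking the return value of `zrem()` is what keeps a task from being enqueued twice if more than one scheduler process happens to be running.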
Links
- Introduction to redis streams – Redis documentation
- Using Redis streams with Python – post I wrote that provides an overview of the streams APIs.
- Walrus, lightweight Python utilities for working with Redis – on GitHub.
- huey, a lightweight task queue – task queue that uses the lpush/brpop pattern.