Write your own miniature Redis with Python


The other day the idea occurred to me that it would be neat to write a simple Redis-like database server. While I've had plenty of experience with WSGI applications, a database server presented a novel challenge and proved to be a nice practical way of learning how to work with sockets in Python. In this post I'll share what I learned along the way.

The goal of my project was to write a simple server that I could use with a task queue project of mine called huey. Huey uses Redis as the default storage engine for tracking enqueued jobs, results of finished jobs, and other things. For the purposes of this post, I've reduced the scope of the original project even further so as not to muddy the waters with code you could very easily write yourself, but if you're curious, you can check out the end result here (documentation).

The server we'll be building will be able to respond to the following commands: GET, SET, DELETE, FLUSH, MGET and MSET.

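We'll support the following data-types as well: strings and binary data, integers, NULL, arrays (which may be nested), dictionaries (which may be nested) and error messages.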

To handle multiple clients asynchronously, we'll be using gevent, but you could also use the standard library's socketserver module (SocketServer in Python 2) with either ForkingMixIn or ThreadingMixIn.
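For comparison, here is a minimal sketch of the standard-library route, assuming Python 3's socketserver module. It only echoes lines back rather than speaking our protocol, and the class and function names (LineEchoHandler, ThreadedServer, start_in_background, echo_once) are illustrative, not part of the server we're building:

```python
import socket
import socketserver
import threading


class LineEchoHandler(socketserver.StreamRequestHandler):
    """Toy handler: echo each line back until the client disconnects."""

    def handle(self):
        # self.rfile / self.wfile are file-like wrappers around the client
        # socket, playing the same role as conn.makefile('rwb') with gevent.
        for line in self.rfile:
            self.wfile.write(line)


class ThreadedServer(socketserver.ThreadingMixIn, socketserver.TCPServer):
    daemon_threads = True        # one daemon thread per connected client
    allow_reuse_address = True   # allow quick restarts on the same port


def start_in_background(host='127.0.0.1', port=0):
    # Port 0 lets the OS pick a free port; see server.server_address.
    server = ThreadedServer((host, port), LineEchoHandler)
    threading.Thread(target=server.serve_forever, daemon=True).start()
    return server


def echo_once(server, payload):
    # Act as a client: send one newline-terminated message, read the echo.
    with socket.create_connection(server.server_address) as conn:
        with conn.makefile('rwb') as f:
            f.write(payload)
            f.flush()
            return f.readline()
```

Swapping ThreadingMixIn for ForkingMixIn (Unix only) would give each client its own process instead, but forked processes do not share memory, which would defeat an in-memory key/value store like ours.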

Skeleton

Let's frame up a skeleton for our server. We'll need the server itself, and a callback to be executed when a new client connects. Additionally we'll need some kind of logic to process the client request and to send a response.

Here's a start:

from gevent import socket
from gevent.pool import Pool
from gevent.server import StreamServer

from collections import namedtuple
from io import BytesIO
from socket import error as socket_error


# We'll use exceptions to notify the connection-handling loop of problems.
class CommandError(Exception): pass
class Disconnect(Exception): pass

Error = namedtuple('Error', ('message',))


class ProtocolHandler(object):
    def handle_request(self, socket_file):
        # Parse a request from the client into its component parts.
        pass

    def write_response(self, socket_file, data):
        # Serialize the response data and send it to the client.
        pass


class Server(object):
    def __init__(self, host='127.0.0.1', port=31337, max_clients=64):
        self._pool = Pool(max_clients)
        self._server = StreamServer(
            (host, port),
            self.connection_handler,
            spawn=self._pool)

        self._protocol = ProtocolHandler()
        self._kv = {}

    def connection_handler(self, conn, address):
        # Convert "conn" (a socket object) into a file-like object.
        socket_file = conn.makefile('rwb')

        # Process client requests until client disconnects.
        while True:
            try:
                data = self._protocol.handle_request(socket_file)
                resp = self.get_response(data)
            except Disconnect:
                break
            except CommandError as exc:
                # Malformed request or failed command: send an error reply.
                resp = Error(exc.args[0])

            self._protocol.write_response(socket_file, resp)

    def get_response(self, data):
        # Here we'll actually unpack the data sent by the client, execute the
        # command they specified, and pass back the return value.
        pass

    def run(self):
        self._server.serve_forever()

The above code is hopefully fairly clear. We've separated concerns so that the protocol handling is in its own class with two public methods: handle_request and write_response. The server itself uses the protocol handler to unpack client requests and serialize server responses back to the client. The get_response() method will be used to execute the command initiated by the client.

Taking a closer look at the code of the connection_handler() method, you can see that we obtain a file-like wrapper around the socket object. This wrapper handles buffering for us and provides convenient read(), readline() and write() methods, abstracting away some of the quirks of working with raw sockets. The function enters an endless loop, reading requests from the client, sending responses, and finally exiting the loop when the client disconnects (indicated by read() returning an empty bytes object, b'').
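The end-of-stream behaviour is easy to observe in isolation. This sketch uses socket.socketpair() to get two connected sockets in one process, so nothing needs to bind a port; the function name read_until_disconnect is hypothetical:

```python
import socket


def read_until_disconnect():
    # socketpair() returns two connected sockets, so one process can play
    # both the client and the server without binding a port.
    server_sock, client_sock = socket.socketpair()
    with server_sock, server_sock.makefile('rwb') as server_file:
        client_sock.sendall(b'+PING\r\n')
        client_sock.close()                # the "client" hangs up

        first_byte = server_file.read(1)   # b'+', the type prefix
        line = server_file.readline()      # b'PING\r\n', buffered for us
        after_close = server_file.read(1)  # b'' once the peer has gone away
    return first_byte, line, after_close
```

Data the client sent before hanging up is still delivered in full; only after it has been consumed does read() return b''.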

We use typed exceptions to handle client disconnects and to notify the user of errors processing commands. For example, if the user makes an improperly formatted request to the server, we will raise a CommandError, which is serialized into an error response and sent to the client.

Before going further, let's discuss how the client and server will communicate.

Wire protocol

The first challenge I faced was how to handle sending binary data over the wire. Most examples I found online were pointless echo servers that converted the socket to a file-like object and just called readline(). If I wanted to store some pickled data or strings with new-lines, I would need to have some kind of serialization format.
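To see the problem concretely, compare newline framing with length-prefixed framing on a payload that itself contains a newline. This is a deliberately simplified, hypothetical framing (a decimal length on its own line followed by the raw bytes), not the Redis format described below:

```python
from io import BytesIO

PAYLOAD = b'first line\nsecond line'


def readline_framing(payload):
    # Naive framing: one message per line. The embedded newline splits the
    # payload, so the reader sees two messages instead of one.
    stream = BytesIO(payload + b'\n')
    return [line.rstrip(b'\n') for line in stream]


def length_prefixed_framing(payload):
    # Length-prefixed framing: announce the byte count first, then read
    # exactly that many bytes, whatever they contain.
    stream = BytesIO(b'%d\n' % len(payload) + payload)
    length = int(stream.readline().rstrip(b'\n'))
    return [stream.read(length)]
```

readline_framing(PAYLOAD) yields two fragments, while length_prefixed_framing(PAYLOAD) returns the original payload intact.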

After wasting time trying to invent something suitable, I decided to read the documentation on the Redis protocol, which turned out to be very simple to implement and has the added benefit of supporting a couple of different data-types.

The Redis protocol uses a request/response communication pattern with the clients. Each message (in our implementation, both requests and responses) uses the first byte to indicate the data-type, followed by the data, terminated by a carriage-return/line-feed (\r\n).

Data-type | Prefix | Structure | Example
--- | --- | --- | ---
Simple string | + | +{string data}\r\n | +this is a simple string\r\n
Error | - | -{error message}\r\n | -ERR unknown command "FLUHS"\r\n
Integer | : | :{the number}\r\n | :1337\r\n
Binary | $ | ${number of bytes}\r\n{data}\r\n | $6\r\nfoobar\r\n
Array | * | *{number of elements}\r\n{0 or more of above} | *3\r\n+a simple string element\r\n:12345\r\n$7\r\ntesting\r\n
Dictionary | % | %{number of keys}\r\n{a key and a value per entry, each one of the above} | %3\r\n+key1\r\n+value1\r\n+key2\r\n*2\r\n+value2-0\r\n+value2-1\r\n:3\r\n$7\r\ntesting\r\n
NULL | $ | $-1\r\n (string of length -1) | $-1\r\n
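One detail of the Binary row worth spelling out: the number after $ counts bytes, not characters, so text has to be encoded before it is measured. A minimal sketch of encoding a single value this way (encode_bulk is a hypothetical helper, assuming UTF-8 for text):

```python
def encode_bulk(value):
    """Encode str, bytes or None as a binary ($) string (hypothetical helper)."""
    if value is None:
        return b'$-1\r\n'                 # the NULL special case
    # The $ header counts bytes, not characters, so encode text first.
    data = value.encode('utf-8') if isinstance(value, str) else value
    return b'$%d\r\n%s\r\n' % (len(data), data)
```

For example, 'héllo' has five characters but encodes to six UTF-8 bytes, so its header is $6. And because the reader consumes exactly the announced number of bytes, a payload such as b'a\r\nb' with an embedded \r\n is transmitted without ambiguity.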

Let's fill in the protocol handler's class so that it implements the Redis protocol.

class ProtocolHandler(object):
    def __init__(self):
        # The socket file is opened in binary mode, so read(1) returns bytes:
        # the dispatch table must be keyed on b'+', not '+'.
        self.handlers = {
            b'+': self.handle_simple_string,
            b'-': self.handle_error,
            b':': self.handle_integer,
            b'$': self.handle_string,
            b'*': self.handle_array,
            b'%': self.handle_dict}

    def handle_request(self, socket_file):
        first_byte = socket_file.read(1)
        if not first_byte:
            raise Disconnect()

        try:
            # Delegate to the appropriate handler based on the first byte.
            return self.handlers[first_byte](socket_file)
        except KeyError:
            raise CommandError('bad request')

    def handle_simple_string(self, socket_file):
        return socket_file.readline().rstrip(b'\r\n')

    def handle_error(self, socket_file):
        return Error(socket_file.readline().rstrip(b'\r\n'))

    def handle_integer(self, socket_file):
        return int(socket_file.readline().rstrip(b'\r\n'))

    def handle_string(self, socket_file):
        # First read the length ($<length>\r\n).
        length = int(socket_file.readline().rstrip(b'\r\n'))
        if length == -1:
            return None  # Special-case for NULLs.
        length += 2  # Include the trailing \r\n in count.
        return socket_file.read(length)[:-2]

    def handle_array(self, socket_file):
        num_elements = int(socket_file.readline().rstrip(b'\r\n'))
        return [self.handle_request(socket_file) for _ in range(num_elements)]

    def handle_dict(self, socket_file):
        num_items = int(socket_file.readline().rstrip(b'\r\n'))
        elements = [self.handle_request(socket_file)
                    for _ in range(num_items * 2)]
        # Pair up alternating elements: key, value, key, value, ...
        return dict(zip(elements[::2], elements[1::2]))

For the serialization side of the protocol, we'll do the opposite of the above: turn Python objects into their serialized counterparts!

class ProtocolHandler(object):
    # ... above methods omitted ...
    def write_response(self, socket_file, data):
        # Serialize into an in-memory buffer, then send it in a single write.
        buf = BytesIO()
        self._write(buf, data)
        socket_file.write(buf.getvalue())
        socket_file.flush()

    def _write(self, buf, data):
        if isinstance(data, str):
            data = data.encode('utf-8')

        if isinstance(data, bytes):
            buf.write(b'$%d\r\n%s\r\n' % (len(data), data))
        elif isinstance(data, int):
            buf.write(b':%d\r\n' % data)
        elif isinstance(data, Error):
            # Messages are str on the server but bytes once parsed off the wire.
            message = data.message
            if isinstance(message, str):
                message = message.encode('utf-8')
            buf.write(b'-%s\r\n' % message)
        elif isinstance(data, (list, tuple)):
            buf.write(b'*%d\r\n' % len(data))
            for item in data:
                self._write(buf, item)
        elif isinstance(data, dict):
            buf.write(b'%%%d\r\n' % len(data))
            for key in data:
                self._write(buf, key)
                self._write(buf, data[key])
        elif data is None:
            buf.write(b'$-1\r\n')
        else:
            raise CommandError('unrecognized type: %s' % type(data))
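A subtlety in _write() is the order of the isinstance() checks. Error is a namedtuple, and every namedtuple is also a tuple, so if the (list, tuple) branch came first an error would be serialized as a one-element array. A stripped-down sketch of that dispatch (type_prefix and type_prefix_wrong_order are hypothetical helpers used only for illustration):

```python
from collections import namedtuple

Error = namedtuple('Error', ('message',))


def type_prefix(data):
    # Same order as _write(): the specific Error check before generic tuples.
    if isinstance(data, Error):
        return b'-'
    if isinstance(data, (list, tuple)):
        return b'*'
    raise TypeError('unhandled type: %r' % type(data))


def type_prefix_wrong_order(data):
    # The same checks reversed: the tuple branch swallows Error values.
    if isinstance(data, (list, tuple)):
        return b'*'
    if isinstance(data, Error):
        return b'-'
    raise TypeError('unhandled type: %r' % type(data))
```

The same reasoning applies to bool, which is a subclass of int: True is sent as the integer :1.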

An additional benefit of keeping the protocol handling in its own class is that we can re-use the handle_request and write_response methods to build a client library.

Implementing Commands

The Server class we mocked up now needs to have its get_response() method implemented. Commands will be assumed to be sent by the client as either simple strings or an array of command arguments, so the data parameter passed to get_response() will be either bytes or a list. To simplify handling, if data is a simple string, we'll convert it to a list by splitting on whitespace.

The first argument will be the command name, with any additional arguments belonging to the specified command. As we did with the mapping of the first byte to the handlers in the ProtocolHandler, let's create a mapping of command to callback in the Server:

class Server(object):
    def __init__(self, host='127.0.0.1', port=31337, max_clients=64):
        self._pool = Pool(max_clients)
        self._server = StreamServer(
            (host, port),
            self.connection_handler,
            spawn=self._pool)

        self._protocol = ProtocolHandler()
        self._kv = {}

        self._commands = self.get_commands()

    def get_commands(self):
        return {
            'GET': self.get,
            'SET': self.set,
            'DELETE': self.delete,
            'FLUSH': self.flush,
            'MGET': self.mget,
            'MSET': self.mset}

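    def get_response(self, data):
        if not isinstance(data, list):
            try:
                data = data.split()
            except AttributeError:
                raise CommandError('Request must be list or simple string.')

        if not data:
            raise CommandError('Missing command')

        # Arguments arrive as bytes; decode the command name so that it
        # matches the str keys returned by get_commands().
        command = data[0]
        if isinstance(command, bytes):
            command = command.decode('utf-8', 'replace')
        if not isinstance(command, str):
            raise CommandError('Command name must be a string')
        command = command.upper()
        if command not in self._commands:
            raise CommandError('Unrecognized command: %s' % command)

        return self._commands[command](*data[1:])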

Our server is almost finished! We just need to implement the six command methods defined in the get_commands() method:

class Server(object):
    def get(self, key):
        return self._kv.get(key)

    def set(self, key, value):
        self._kv[key] = value
        return 1

    def delete(self, key):
        if key in self._kv:
            del self._kv[key]
            return 1
        return 0

    def flush(self):
        kvlen = len(self._kv)
        self._kv.clear()
        return kvlen

    def mget(self, *keys):
        return [self._kv.get(key) for key in keys]

    def mset(self, *items):
        # list() is needed on Python 3, where zip() returns an iterator.
        data = list(zip(items[::2], items[1::2]))
        for key, value in data:
            self._kv[key] = value
        return len(data)
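Because mset() pairs its arguments with slicing and zip(), an odd number of arguments is not rejected: zip() stops at the shorter sequence, so a trailing key with no value is silently ignored. If you would rather refuse such requests, a helper along these lines could be used (pair_items is a hypothetical name; inside the server you would raise CommandError instead of ValueError so the client receives an error reply):

```python
def pair_items(items):
    """Turn (k1, v1, k2, v2, ...) into [(k1, v1), (k2, v2), ...]."""
    if len(items) % 2:
        # zip() alone would silently drop the unpaired trailing key.
        raise ValueError('MSET requires an even number of arguments')
    return list(zip(items[::2], items[1::2]))
```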

That's it! Our server is now ready to start processing requests. In the next section we'll implement a client to interact with the server.

Client

To interact with the server, let's re-use the ProtocolHandler class to implement a simple client. The client will connect to the server and send commands encoded as lists. We'll re-use both the write_response() and the handle_request() logic for encoding requests and processing server responses respectively.

class Client(object):
    def __init__(self, host='127.0.0.1', port=31337):
        self._protocol = ProtocolHandler()
        self._socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        self._socket.connect((host, port))
        self._fh = self._socket.makefile('rwb')

    def execute(self, *args):
        self._protocol.write_response(self._fh, args)
        resp = self._protocol.handle_request(self._fh)
        if isinstance(resp, Error):
            raise CommandError(resp.message)
        return resp

With the execute() method, we can pass an arbitrary list of parameters which will be encoded as an array and sent to the server. The response from the server is parsed and returned as a Python object. For convenience, we can write client methods for the individual commands:

class Client(object):
    # ...
    def get(self, key):
        return self.execute('GET', key)

    def set(self, key, value):
        return self.execute('SET', key, value)

    def delete(self, key):
        return self.execute('DELETE', key)

    def flush(self):
        return self.execute('FLUSH')

    def mget(self, *keys):
        return self.execute('MGET', *keys)

    def mset(self, *items):
        return self.execute('MSET', *items)
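To make the exchange concrete: execute() hands its argument tuple to write_response(), which encodes it as an array of binary strings. This standalone sketch (encode_command is a hypothetical helper that handles only str and bytes arguments) reproduces the bytes a call such as client.set('k', 'v') puts on the wire:

```python
def encode_command(*args):
    # An array header (*N) followed by one length-prefixed binary string
    # ($len) per argument, mirroring how _write() encodes a tuple of strings.
    out = [b'*%d\r\n' % len(args)]
    for arg in args:
        data = arg.encode('utf-8') if isinstance(arg, str) else arg
        out.append(b'$%d\r\n%s\r\n' % (len(data), data))
    return b''.join(out)
```

encode_command('SET', 'k', 'v') returns b'*3\r\n$3\r\nSET\r\n$1\r\nk\r\n$1\r\nv\r\n', which the server parses back into the list [b'SET', b'k', b'v'].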

To test out our client, let's configure our Python script to start up a server when executed directly from the command-line:

# Add this to bottom of module:
if __name__ == '__main__':
    from gevent import monkey; monkey.patch_all()
    Server().run()

Testing the Server

To test the server, just execute the server's Python module from the command line. In another terminal, open up a Python interpreter and import the Client class from the server's module. Instantiating the client will open a connection and you can start running commands! Note that string values come back as bytes, since they travel over the wire as binary data.

>>> from server_ex import Client
>>> client = Client()
>>> client.mset('k1', 'v1', 'k2', ['v2-0', 1, 'v2-2'], 'k3', 'v3')
3
>>> client.get('k2')
[b'v2-0', 1, b'v2-2']
>>> client.mget('k3', 'k1')
[b'v3', b'v1']
>>> client.delete('k1')
1
>>> client.get('k1')
>>> client.delete('k1')
0
>>> client.set('kx', {'vx': {'vy': 0, 'vz': [1, 2, 3]}})
1
>>> client.get('kx')
{b'vx': {b'vy': 0, b'vz': [1, 2, 3]}}
>>> client.flush()
3

The code presented in this post is absolutely for demonstration purposes only. I hope you enjoyed reading about this project as much as I enjoyed writing about it. You can find a complete copy of the code here. To see a more complete example, check out simpledb (documentation).

To extend the project, you might consider:

Comments (3)

Sami | jan 05 2018, at 12:17am

If you use CLOCK-Pro algorithm with (or instead) dictionary. Then you'll get cache functionality, which will evict least needed data on overflow. https://bitbucket.org/SamiLehtinen/pyclockpro

Charles | jan 03 2018, at 12:48pm

As written all data is stored in memory (in ._kv), though a suggested exercise is to re-use the protocol handler to write an append-only log so you could thereby reconstruct the in-memory data-set by replaying the log.

Marlysson | jan 03 2018, at 06:52am

Awesome idea, good job man, It handle with some type of either cache or "just" works similar to dict in memory?

