Write your own miniature Redis with Python
The other day the idea occurred to me that it would be neat to write a simple Redis-like database server. While I've had plenty of experience with WSGI applications, a database server presented a novel challenge and proved to be a nice practical way of learning how to work with sockets in Python. In this post I'll share what I learned along the way.
The goal of my project was to write a simple server that I could use with a task queue project of mine called huey. Huey uses Redis as the default storage engine for tracking enqueued jobs, results of finished jobs, and other things. For the purposes of this post, I've reduced the scope of the original project even further so as not to muddy the waters with code you could very easily write yourself, but if you're curious, you can check out the end result here (documentation).
The server we'll be building will be able to respond to the following commands:
- GET <key>
- SET <key> <value>
- DELETE <key>
- FLUSH
- MGET <key1> ... <keyn>
- MSET <key1> <value1> ... <keyn> <valuen>
We'll support the following data-types as well:
- Strings and Binary Data
- Numbers
- NULL
- Arrays (which may be nested)
- Dictionaries (which may be nested)
- Error messages
To handle multiple clients asynchronously, we'll be using gevent, but you could also use the standard library's socketserver module (SocketServer in Python 2) with either the ForkingMixIn or the ThreadingMixIn.
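For comparison, here is a minimal sketch of the standard-library route, assuming Python 3's socketserver module. LineHandler and ThreadedServer are hypothetical names, and the handler simply answers each line with its length in bytes; it is not the key/value server built in the rest of this post.

```python
import socket
import socketserver
import threading


class LineHandler(socketserver.StreamRequestHandler):
    # Hypothetical handler: replies to each line with its length in bytes.
    def handle(self):
        # self.rfile / self.wfile are file-like wrappers around the socket.
        for line in self.rfile:
            payload = line.rstrip(b'\r\n')
            self.wfile.write(b':%d\r\n' % len(payload))


class ThreadedServer(socketserver.ThreadingMixIn, socketserver.TCPServer):
    allow_reuse_address = True
    daemon_threads = True  # Don't block interpreter exit on open clients.


def demo():
    # Port 0 asks the OS for any free port.
    with ThreadedServer(('127.0.0.1', 0), LineHandler) as server:
        threading.Thread(target=server.serve_forever, daemon=True).start()
        address = server.server_address
        with socket.create_connection(address) as conn, \
                conn.makefile('rb') as reader:
            conn.sendall(b'hello\r\n')
            reply = reader.readline()
        server.shutdown()
    return reply


reply = demo()
```

Each client connection is served in its own thread, so a slow client does not block the others. The trade-off compared with gevent is one OS thread per connection rather than one lightweight greenlet.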
Skeleton
Let's frame up a skeleton for our server. We'll need the server itself, and a callback to be executed when a new client connects. Additionally we'll need some kind of logic to process the client request and to send a response.
Here's a start:
from gevent import socket
from gevent.pool import Pool
from gevent.server import StreamServer

from collections import namedtuple
from io import BytesIO
from socket import error as socket_error


# We'll use exceptions to notify the connection-handling loop of problems.
class CommandError(Exception): pass
class Disconnect(Exception): pass

Error = namedtuple('Error', ('message',))


class ProtocolHandler(object):
    def handle_request(self, socket_file):
        # Parse a request from the client into its component parts.
        pass

    def write_response(self, socket_file, data):
        # Serialize the response data and send it to the client.
        pass


class Server(object):
    def __init__(self, host='127.0.0.1', port=31337, max_clients=64):
        self._pool = Pool(max_clients)
        self._server = StreamServer(
            (host, port),
            self.connection_handler,
            spawn=self._pool)

        self._protocol = ProtocolHandler()
        self._kv = {}

    def connection_handler(self, conn, address):
        # Convert "conn" (a socket object) into a file-like object.
        socket_file = conn.makefile('rwb')

        # Process client requests until client disconnects.
        while True:
            try:
                data = self._protocol.handle_request(socket_file)
            except Disconnect:
                break

            try:
                resp = self.get_response(data)
            except CommandError as exc:
                resp = Error(exc.args[0])

            self._protocol.write_response(socket_file, resp)

    def get_response(self, data):
        # Here we'll actually unpack the data sent by the client, execute the
        # command they specified, and pass back the return value.
        pass

    def run(self):
        self._server.serve_forever()
The above code is hopefully fairly clear. We've separated concerns so that the protocol handling is in its own class with two public methods: handle_request and write_response. The server itself uses the protocol handler to unpack client requests and serialize server responses back to the client. The get_response() method will be used to execute the command initiated by the client.
Taking a closer look at the code of the connection_handler() method, you can see that we obtain a file-like wrapper around the socket object. This wrapper allows us to abstract away some of the quirks one typically encounters working with raw sockets. The function enters an endless loop, reading requests from the client, sending responses, and finally exiting the loop when the client disconnects (indicated by read() returning an empty string).
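To see the disconnect behaviour in isolation, here is a small sketch that uses socket.socketpair() in place of a real client and server (an assumption made purely to keep the example short). Note that with a file opened in 'rwb' mode, the "empty string" is the empty bytes object b''.

```python
import socket

# socketpair() gives two connected sockets, so no server is needed.
server_side, client_side = socket.socketpair()
socket_file = server_side.makefile('rwb')

client_side.sendall(b'+PING\r\n')
client_side.close()  # The "client" disconnects after one request.

first_byte = socket_file.read(1)   # The data-type prefix.
line = socket_file.readline()      # The rest of the request.
at_eof = socket_file.read(1)       # Empty once the peer has closed.

socket_file.close()
server_side.close()
```

Data sent before the close is still delivered; only after it has been consumed does read() return an empty result, which is the signal the loop uses to stop.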
We use typed exceptions to handle client disconnects and to notify the user of errors processing commands. For example, if the user makes an improperly formatted request to the server, we will raise a CommandError, which is serialized into an error response and sent to the client.
Before going further, let's discuss how the client and server will communicate.
Wire protocol
The first challenge I faced was how to handle sending binary data over the wire. Most examples I found online were pointless echo servers that converted the socket to a file-like object and just called readline(). If I wanted to store some pickled data or strings with new-lines, I would need to have some kind of serialization format.
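The newline problem is easy to reproduce without any sockets. The sketch below uses BytesIO to stand in for the socket file and compares newline-delimited reading with a length-prefixed frame. The framing shown here is just one possible approach, chosen for illustration.

```python
from io import BytesIO

payload = b'line one\nline two'  # Binary-safe data may contain newlines.

# Newline-delimited framing: the reader stops at the first b'\n'.
naive = BytesIO(payload + b'\n')
truncated = naive.readline().rstrip(b'\n')

# Length-prefixed framing: announce the size, then read exactly that much.
framed = BytesIO((b'%d\r\n' % len(payload)) + payload + b'\r\n')
size = int(framed.readline().rstrip(b'\r\n'))
recovered = framed.read(size)
```

With readline() alone, only the first line of the payload survives. Reading a known number of bytes gets the whole value back regardless of what it contains.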
After wasting time trying to invent something suitable, I decided to read the documentation on the Redis protocol, which turned out to be very simple to implement and has the added benefit of supporting a couple different data-types.
The Redis protocol uses a request/response communication pattern with the clients. Responses from the server will use the first byte to indicate data-type, followed by the data, terminated by a carriage-return/line-feed.
Data-type | Prefix | Structure | Example |
---|---|---|---|
Simple string | + | +{string data}\r\n | +this is a simple string\r\n |
Error | - | -{error message}\r\n | -ERR unknown command "FLUHS"\r\n |
Integer | : | :{the number}\r\n | :1337\r\n |
Binary | $ | ${number of bytes}\r\n{data}\r\n | $6\r\nfoobar\r\n |
Array | * | *{number of elements}\r\n{0 or more of above}\r\n | *3\r\n+a simple string element\r\n:12345\r\n$7\r\ntesting\r\n |
Dictionary | % | %{number of keys}\r\n{0 or more of above}\r\n | %3\r\n+key1\r\n+value1\r\n+key2\r\n*2\r\n+value2-0\r\n+value2-1\r\n:3\r\n$7\r\ntesting\r\n |
NULL | $ | $-1\r\n (string of length -1) | $-1\r\n |
Let's fill in the protocol handler's class so that it implements the Redis protocol.
class ProtocolHandler(object):
    def __init__(self):
        # The socket file is opened in binary mode, so prefixes are bytes.
        self.handlers = {
            b'+': self.handle_simple_string,
            b'-': self.handle_error,
            b':': self.handle_integer,
            b'$': self.handle_string,
            b'*': self.handle_array,
            b'%': self.handle_dict}

    def handle_request(self, socket_file):
        first_byte = socket_file.read(1)
        if not first_byte:
            raise Disconnect()

        try:
            # Delegate to the appropriate handler based on the first byte.
            handler = self.handlers[first_byte]
        except KeyError:
            raise CommandError('bad request')
        return handler(socket_file)

    def handle_simple_string(self, socket_file):
        return socket_file.readline().rstrip(b'\r\n').decode('utf-8')

    def handle_error(self, socket_file):
        return Error(socket_file.readline().rstrip(b'\r\n').decode('utf-8'))

    def handle_integer(self, socket_file):
        return int(socket_file.readline().rstrip(b'\r\n'))

    def handle_string(self, socket_file):
        # First read the length ($<length>\r\n).
        length = int(socket_file.readline().rstrip(b'\r\n'))
        if length == -1:
            return None  # Special-case for NULLs.
        length += 2  # Include the trailing \r\n in count.
        data = socket_file.read(length)[:-2]
        try:
            # Return text as str; data that is not valid UTF-8 stays bytes.
            return data.decode('utf-8')
        except UnicodeDecodeError:
            return data

    def handle_array(self, socket_file):
        num_elements = int(socket_file.readline().rstrip(b'\r\n'))
        return [self.handle_request(socket_file) for _ in range(num_elements)]

    def handle_dict(self, socket_file):
        num_items = int(socket_file.readline().rstrip(b'\r\n'))
        elements = [self.handle_request(socket_file)
                    for _ in range(num_items * 2)]
        return dict(zip(elements[::2], elements[1::2]))
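Because the handlers only need a file-like object, the parsing logic can be exercised without a network connection. The sketch below is a condensed, function-based restatement (parse is a hypothetical name, and bulk strings are left as raw bytes here) applied to the dictionary example from the table above.

```python
from io import BytesIO


def parse(stream):
    # Condensed restatement of the handlers, for illustration only.
    prefix = stream.read(1)
    line = stream.readline().rstrip(b'\r\n')
    if prefix == b'+':
        return line.decode('utf-8')
    if prefix == b':':
        return int(line)
    if prefix == b'$':
        length = int(line)
        if length == -1:
            return None
        return stream.read(length + 2)[:-2]  # Drop the trailing \r\n.
    if prefix == b'*':
        return [parse(stream) for _ in range(int(line))]
    if prefix == b'%':
        items = [parse(stream) for _ in range(int(line) * 2)]
        return dict(zip(items[::2], items[1::2]))
    raise ValueError('bad prefix: %r' % prefix)


wire = (b'%3\r\n'
        b'+key1\r\n+value1\r\n'
        b'+key2\r\n*2\r\n+value2-0\r\n+value2-1\r\n'
        b':3\r\n$7\r\ntesting\r\n')
result = parse(BytesIO(wire))
# result == {'key1': 'value1', 'key2': ['value2-0', 'value2-1'], 3: b'testing'}
```

Arrays and dictionaries recurse back into the same entry point, which is what lets nested values work without any extra code.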
For the serialization side of the protocol, we'll do the opposite of the above: turn Python objects into their serialized counterparts!
class ProtocolHandler(object):
    # ... above methods omitted ...
    def write_response(self, socket_file, data):
        buf = BytesIO()
        self._write(buf, data)
        socket_file.write(buf.getvalue())
        socket_file.flush()

    def _write(self, buf, data):
        # The buffer holds bytes, so every piece written must be bytes too.
        if isinstance(data, str):
            data = data.encode('utf-8')

        if isinstance(data, bytes):
            buf.write(b'$%d\r\n%s\r\n' % (len(data), data))
        elif isinstance(data, int):
            buf.write(b':%d\r\n' % data)
        elif isinstance(data, Error):
            buf.write(b'-%s\r\n' % data.message.encode('utf-8'))
        elif isinstance(data, (list, tuple)):
            buf.write(b'*%d\r\n' % len(data))
            for item in data:
                self._write(buf, item)
        elif isinstance(data, dict):
            buf.write(b'%%%d\r\n' % len(data))
            for key in data:
                self._write(buf, key)
                self._write(buf, data[key])
        elif data is None:
            buf.write(b'$-1\r\n')
        else:
            raise CommandError('unrecognized type: %s' % type(data))
An additional benefit of keeping the protocol handling in its own class is that we can re-use the handle_request and write_response methods to build a client library.
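To make the wire format concrete, here's what a command looks like once serialized. encode is a hypothetical stand-alone function that mirrors the string, integer, NULL and array branches of _write, returning bytes instead of writing to a buffer.

```python
def encode(value):
    # Hypothetical stand-alone encoder covering strings, ints, None, arrays.
    if isinstance(value, str):
        value = value.encode('utf-8')
    if isinstance(value, bytes):
        return b'$%d\r\n%s\r\n' % (len(value), value)
    if isinstance(value, int):
        return b':%d\r\n' % value
    if value is None:
        return b'$-1\r\n'
    if isinstance(value, (list, tuple)):
        header = b'*%d\r\n' % len(value)
        return header + b''.join(encode(item) for item in value)
    raise TypeError('unsupported type: %r' % type(value))


request = encode(('SET', 'k1', 'v1'))
# request == b'*3\r\n$3\r\nSET\r\n$2\r\nk1\r\n$2\r\nv1\r\n'
```

A client sends exactly this kind of array: the command name first, followed by its arguments, each as a length-prefixed string.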
Implementing Commands
The Server class we mocked up now needs to have its get_response() method implemented. Commands will be assumed to be sent by the client as either simple strings or an array of command arguments, so the data parameter passed to get_response() will be either a string or a list. To simplify handling, if data is a simple string, we'll convert it to a list by splitting on whitespace.
The first argument will be the command name, with any additional arguments belonging to the specified command. As we did with the mapping of the first byte to the handlers in the ProtocolHandler, let's create a mapping of command to callback in the Server:
class Server(object):
    def __init__(self, host='127.0.0.1', port=31337, max_clients=64):
        self._pool = Pool(max_clients)
        self._server = StreamServer(
            (host, port),
            self.connection_handler,
            spawn=self._pool)

        self._protocol = ProtocolHandler()
        self._kv = {}

        self._commands = self.get_commands()

    def get_commands(self):
        return {
            'GET': self.get,
            'SET': self.set,
            'DELETE': self.delete,
            'FLUSH': self.flush,
            'MGET': self.mget,
            'MSET': self.mset}

    def get_response(self, data):
        if not isinstance(data, list):
            try:
                data = data.split()
            except AttributeError:
                # Not a string-like object (e.g. an integer or None).
                raise CommandError('Request must be list or simple string.')

        if not data:
            raise CommandError('Missing command')

        command = data[0].upper()
        if command not in self._commands:
            raise CommandError('Unrecognized command: %s' % command)

        return self._commands[command](*data[1:])
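One gap worth knowing about: calling a command with the wrong number of arguments raises a TypeError rather than a CommandError. A possible way to handle that is sketched below with a hypothetical MiniStore class that has no networking and only two commands. Catching TypeError this broadly will also swallow type errors raised inside a command, so treat it as a starting point.

```python
class CommandError(Exception):
    pass


class MiniStore(object):
    # Hypothetical stand-in for Server, with no networking.
    def __init__(self):
        self._kv = {}
        self._commands = {'GET': self.get, 'SET': self.set}

    def get(self, key):
        return self._kv.get(key)

    def set(self, key, value):
        self._kv[key] = value
        return 1

    def dispatch(self, data):
        if isinstance(data, str):
            data = data.split()
        if not data:
            raise CommandError('Missing command')
        command, args = data[0].upper(), data[1:]
        if command not in self._commands:
            raise CommandError('Unrecognized command: %s' % command)
        try:
            return self._commands[command](*args)
        except TypeError:
            # Most likely the wrong number of arguments for the command.
            raise CommandError('Wrong arguments for %s' % command)


store = MiniStore()
store.dispatch('SET k1 v1')             # Simple-string form.
value = store.dispatch(['get', 'k1'])   # List form, case-insensitive.
try:
    store.dispatch('GET')               # Missing the key argument.
except CommandError as exc:
    message = str(exc)
```

With this in place the client receives an error response instead of the connection handler dying on an unhandled exception.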
Our server is almost finished! We just need to implement the six command methods defined in the get_commands() method:
class Server(object):
    def get(self, key):
        return self._kv.get(key)

    def set(self, key, value):
        self._kv[key] = value
        return 1

    def delete(self, key):
        if key in self._kv:
            del self._kv[key]
            return 1
        return 0

    def flush(self):
        kvlen = len(self._kv)
        self._kv.clear()
        return kvlen

    def mget(self, *keys):
        return [self._kv.get(key) for key in keys]

    def mset(self, *items):
        # zip() is lazy in Python 3, so build a list before calling len().
        data = list(zip(items[::2], items[1::2]))
        for key, value in data:
            self._kv[key] = value
        return len(data)
That's it! Our server is now ready to start processing requests. In the next section we'll implement a client to interact with the server.
Client
To interact with the server, let's re-use the ProtocolHandler class to implement a simple client. The client will connect to the server and send commands encoded as lists. We'll re-use both the write_response() and the handle_request() logic for encoding requests and processing server responses respectively.
class Client(object):
    def __init__(self, host='127.0.0.1', port=31337):
        self._protocol = ProtocolHandler()
        self._socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        self._socket.connect((host, port))
        self._fh = self._socket.makefile('rwb')

    def execute(self, *args):
        self._protocol.write_response(self._fh, args)
        resp = self._protocol.handle_request(self._fh)
        if isinstance(resp, Error):
            raise CommandError(resp.message)
        return resp
With the execute() method, we can pass an arbitrary list of parameters which will be encoded as an array and sent to the server. The response from the server is parsed and returned as a Python object. For convenience, we can write client methods for the individual commands:
class Client(object):
    # ...
    def get(self, key):
        return self.execute('GET', key)

    def set(self, key, value):
        return self.execute('SET', key, value)

    def delete(self, key):
        return self.execute('DELETE', key)

    def flush(self):
        return self.execute('FLUSH')

    def mget(self, *keys):
        return self.execute('MGET', *keys)

    def mset(self, *items):
        return self.execute('MSET', *items)
To test out our client, let's configure our Python script to start up a server when executed directly from the command-line:
# Add this to bottom of module:
if __name__ == '__main__':
    from gevent import monkey; monkey.patch_all()
    Server().run()
Testing the Server
To test the server, just execute the server's Python module from the command line. In another terminal, open up a Python interpreter and import the Client class from the server's module. Instantiating the client will open a connection and you can start running commands!
>>> from server_ex import Client
>>> client = Client()
>>> client.mset('k1', 'v1', 'k2', ['v2-0', 1, 'v2-2'], 'k3', 'v3')
3
>>> client.get('k2')
['v2-0', 1, 'v2-2']
>>> client.mget('k3', 'k1')
['v3', 'v1']
>>> client.delete('k1')
1
>>> client.get('k1')
>>> client.delete('k1')
0
>>> client.set('kx', {'vx': {'vy': 0, 'vz': [1, 2, 3]}})
1
>>> client.get('kx')
{'vx': {'vy': 0, 'vz': [1, 2, 3]}}
>>> client.flush()
3
The code presented in this post is absolutely for demonstration purposes only. I hope you enjoyed reading about this project as much as I enjoyed writing about it. You can find a complete copy of the code here. To see a more complete example, check out simpledb (documentation).
To extend the project, you might consider:
- Add more commands!
- Use the protocol handler to implement an append-only command log
- More robust error handling
- Allow client to close connection and re-connect
- Logging
- Re-write to use the standard library's socketserver and ThreadingMixIn
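As a starting point for the append-only log idea, here is a rough sketch under some simplifying assumptions: commands are logged as arrays of length-prefixed strings, all arguments are text, and the log is an in-memory BytesIO rather than a file. The function names (encode_command, read_commands, replay) are made up for this example.

```python
from io import BytesIO


def encode_command(args):
    # Encode a command such as ('SET', 'k', 'v') as an array of strings.
    parts = [b'*%d\r\n' % len(args)]
    for arg in args:
        raw = arg.encode('utf-8')
        parts.append(b'$%d\r\n%s\r\n' % (len(raw), raw))
    return b''.join(parts)


def read_commands(log):
    # Yield each logged command as a list of strings until the log ends.
    while True:
        header = log.readline()
        if not header:
            return
        count = int(header[1:].rstrip(b'\r\n'))
        args = []
        for _ in range(count):
            length = int(log.readline()[1:].rstrip(b'\r\n'))
            args.append(log.read(length + 2)[:-2].decode('utf-8'))
        yield args


def replay(log):
    # Rebuild the key/value data by re-applying every logged write.
    kv = {}
    for command, *args in read_commands(log):
        if command == 'SET':
            kv[args[0]] = args[1]
        elif command == 'DELETE':
            kv.pop(args[0], None)
        elif command == 'FLUSH':
            kv.clear()
    return kv


log = BytesIO()  # A real log would be a file opened in append mode.
history = [('SET', 'k1', 'v1'), ('SET', 'k2', 'v2'),
           ('DELETE', 'k1'), ('SET', 'k2', 'v2b')]
for command in history:
    log.write(encode_command(command))

log.seek(0)
state = replay(log)
# state == {'k2': 'v2b'}
```

Only commands that change data need to be logged; reads such as GET and MGET can be skipped since replaying them has no effect.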
Comments (3)
Charles | jan 03 2018, at 12:48pm
As written all data is stored in memory (in ._kv), though a suggested exercise is to re-use the protocol handler to write an append-only log so you could thereby reconstruct the in-memory data-set by replaying the log.
Marlysson | jan 03 2018, at 06:52am
Awesome idea, good job man. Does it handle some type of cache, or does it "just" work like a dict in memory?
Sami | jan 05 2018, at 12:17am
If you use the CLOCK-Pro algorithm with (or instead of) the dictionary, you'll get cache functionality, which will evict the least-needed data on overflow. https://bitbucket.org/SamiLehtinen/pyclockpro