Window functions, case statements, and savepoints in peewee
In case you've missed the last few releases, I've been busy adding some fun new features to peewee. While the changelog and the docs explain the new features and describe their usage, I thought I'd write a blog post to provide a bit more context.
Most of these features were requested by peewee users. I depend heavily on users like you to help me improve peewee, so thank you very much! Not only have your feature requests helped make peewee a better library, they've helped me become a better programmer.
So what's new in peewee? Here is something of an overview:
- Window functions.
- CASE statements.
- Savepoints for nested transactions.
- Array and JSON fields with Postgresql.
- Union, Intersect and Except compound queries.
Hopefully some of those things sound interesting. In this post I will not be discussing everything, but will hit some of the highlights.
Someone recently asked me whether peewee supported Postgresql's window functions. At the time peewee did not, as I kind of had mentally lumped window functions together under the "exotic SQL things you'll never use" category. After spending some time reading, I changed my mind and decided peewee absolutely needed to support window functions (I wanted to start using them on a side project, of course).
If you're also new to window functions, there are many excellent guides online -- I really enjoyed these slides from a PGCon talk by Hitoshi Harada. My understanding is that window functions allow you to mix multi-row aggregations with individual result rows. So window functions are kind of like aggregations, except the aggregate is computed over a subset (a "partition") of the result set and its value is then available on each individual row.
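That distinction can be sketched in a few lines of plain Python: a GROUP BY aggregate collapses rows, while a window function computes the same aggregate per partition but keeps every row. (The sample data below is made up for illustration.)

```python
from collections import defaultdict

rows = [("/blog/", 1.5), ("/blog/", 0.7), ("/about/", 4.0)]

# GROUP BY url + AVG(): the rows collapse to one output row per URL.
by_url = defaultdict(list)
for url, rt in rows:
    by_url[url].append(rt)
grouped = {url: sum(times) / len(times) for url, times in by_url.items()}

# AVG() OVER (PARTITION BY url): every input row survives, annotated
# with the aggregate computed over its partition.
windowed = [(url, rt, grouped[url]) for url, rt in rows]

print(grouped)   # one entry per URL
print(windowed)  # one entry per page-view
```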
Let's say we're tracking page-views on our site and want to look at response times to see which pages are slow. To get the average response time for a page, you would use normal SQL aggregates to write something like:
```python
query = (PageView
         .select(PageView.url, fn.AVG(PageView.response_time))
         .group_by(PageView.url)
         .order_by(fn.AVG(PageView.response_time).desc())
         .tuples())
```
This type of query is fine, but it doesn't really give us an indication of how each individual page-view contributed to the average. It would be cool if we could get an idea of which individual page-views were slower than average.
What if we wanted to list out all the PageViews, sorted by how much slower they were than the average response time of requests to the same URL? That is, display the difference between the PageView's response time and the average response time for other requests to the same URL. To do that, we can use window functions. We will take the average of the response time then partition it by the requested URL:
```python
window = fn.AVG(PageView.response_time).over(partition_by=[PageView.url])

query = (PageView
         .select(
             PageView.url,
             PageView.response_time,
             window.alias('avg_response'),
             (PageView.response_time - window).alias('difference'),
             PageView.timestamp)
         .order_by(SQL('difference').desc()))

for url, response_time, avg_response, diff, timestamp in query.tuples():
    print url, response_time, avg_response, diff, timestamp
```
Here is a sampling of what the results might look like. The average response time is calculated for each distinct URL, then included on every result row. We can then compare each response time with that average and sort by the difference. The table below shows which page-views were significantly slower than the average for that URL. We can see that the worst offenders all occurred around 4 A.M., which coincidentally is when the backups are run.
```
    url    | resp_time | avg_response |  diff  |      timestamp
-----------+-----------+--------------+--------+---------------------
 /blog/    |       2.5 |        1.099 |  1.400 | 2014-01-02 04:xx:xx
 /about/   |       4.3 |        3.342 |  0.957 | 2014-01-02 04:xx:xx
 /blog/    |         2 |        1.099 |  0.900 | 2014-01-02 04:xx:xx
 /about/   |         4 |        3.342 |  0.657 | 2014-01-02 04:xx:xx
 /about/   |       3.9 |        3.342 |  0.557 | 2014-01-02 04:xx:xx
 /blog/    |       1.5 |        1.099 |  0.400 | 2014-01-02 04:xx:xx
 /blog/    |       1.5 |        1.099 |  0.400 | 2014-01-01 01:xx:xx
 /about/   |       3.7 |        3.342 |  0.357 | 2014-01-02 17:xx:xx
 /blog/    |       1.4 |        1.099 |  0.299 | 2014-01-01 22:xx:xx
 /blog/    |       1.4 |        1.099 |  0.299 | 2014-01-01 11:xx:xx
 /contact/ |       2.5 |        2.308 |  0.191 | 2014-01-01 07:xx:xx
```
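To see exactly what this query computes, here is the same "difference from the partition average" logic sketched in plain Python, using a handful of made-up (url, response_time) pairs:

```python
from collections import defaultdict

# Hypothetical (url, response_time) samples.
views = [("/blog/", 2.5), ("/blog/", 1.5), ("/about/", 4.3), ("/about/", 3.7)]

# AVG(response_time) OVER (PARTITION BY url).
by_url = defaultdict(list)
for url, rt in views:
    by_url[url].append(rt)
avg = {url: sum(times) / len(times) for url, times in by_url.items()}

# Annotate each row with its partition's average and the difference,
# then sort slowest-relative-to-average first (ORDER BY difference DESC).
report = sorted(
    ((url, rt, avg[url], rt - avg[url]) for url, rt in views),
    key=lambda row: row[3],
    reverse=True)

for row in report:
    print(row)
```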
Similarly, what if we wanted to compare the response time for a request with the average response time for all other requests on the same day? To do this we will look at a window containing the average response times over the course of a single day (using the SQL `date_trunc` function to partition the timestamps by day):
```python
window = (fn.AVG(PageView.response_time)
          .over(partition_by=[fn.date_trunc('day', PageView.timestamp)]))

query = (PageView
         .select(
             PageView.url,
             PageView.response_time,
             window.alias('avg_response'),
             PageView.timestamp)
         .order_by(PageView.timestamp))

for url, response_time, avg_response, timestamp in query.tuples():
    print url, response_time, avg_response, timestamp
```
Here are some sample results - the average here is for the day on which the PageView occurred:
```
    url    | resp_time | avg_response |      timestamp
-----------+-----------+--------------+---------------------
 /blog/    |       0.7 |        1.535 | 2014-01-01 xx:xx:xx
 /contact/ |       2.5 |        1.535 | 2014-01-01 xx:xx:xx
 /blog/    |       1.4 |        1.535 | 2014-01-01 xx:xx:xx
 /about/   |       3.2 |        1.535 | 2014-01-01 xx:xx:xx
 ...
 /contact/ |     2.125 |        1.744 | 2014-01-02 xx:xx:xx
 /blog/    |       1.5 |        1.744 | 2014-01-02 xx:xx:xx
 /about/   |         4 |        1.744 | 2014-01-02 xx:xx:xx
 /blog/    |         1 |        1.744 | 2014-01-02 xx:xx:xx
```
We can combine the two to compare a single response time with the average response time for all other requests to the same URL, on the same day. To do this we will include both the URL and the `date_trunc` of the timestamp in the partition_by clause when setting up our window:
```python
window = fn.AVG(PageView.response_time).over(
    partition_by=[PageView.url, fn.date_trunc('day', PageView.timestamp)])

query = (PageView
         .select(
             PageView.url,
             PageView.response_time,
             window.alias('avg_response'),
             PageView.timestamp)
         .order_by(PageView.timestamp))
```
This gives us an avg_response for each URL + day combination:
```
    url    | resp_time | avg_response |      timestamp
-----------+-----------+--------------+---------------------
 /blog/    |       0.7 |        1.149 | 2014-01-01 xx:xx:xx
 /blog/    |       0.9 |        1.149 | 2014-01-01 xx:xx:xx
 /contact/ |       2.6 |          2.6 | 2014-01-01 xx:xx:xx
 /about/   |         2 |          2.5 | 2014-01-01 xx:xx:xx
 /about/   |       2.3 |          2.5 | 2014-01-01 xx:xx:xx
 ...
 /blog/    |         1 |         1.05 | 2014-01-02 xx:xx:xx
 /blog/    |       1.1 |         1.05 | 2014-01-02 xx:xx:xx
 /contact/ |     2.125 |        2.212 | 2014-01-02 17:00:00
 /about/   |         4 |        3.975 | 2014-01-02 18:00:00
 /about/   |       3.7 |        3.975 | 2014-01-02 15:00:00
```
These examples all show how you can perform aggregates on partitions of the result set, but that is only half of what window functions can do. Window functions can also look at rows in relation to other rows in the result set.
Suppose I'm running a race and the runners are in one of two categories (youth and adult). I've got their time (in seconds) and category, and would like to let each runner know what order they finished in for their category. The raw data looks like this:
```
 finish_time | category
-------------+----------
         339 | adult
         366 | adult
         372 | adult
         384 | adult
         407 | adult
         412 | adult
         414 | adult
         447 | adult
         484 | adult
         315 | youth
         321 | youth
         362 | youth
         363 | youth
         366 | youth
         430 | youth
         449 | youth
         467 | youth
         476 | youth
         485 | youth
         485 | youth
```
As you guessed, we'll be using window functions. We will use the `rank()` function, partitioned by category, to get each runner's place:
```python
query = (Times
         .select(
             Times.finish_time,
             Times.category,
             fn.rank().over(
                 partition_by=[Times.category],
                 order_by=[Times.finish_time]).alias('place'))
         .order_by(Times.category, SQL('place')))

for ft, cat, place in query.tuples():
    print ft, cat, place
```
This prints the following table. Note that two runners in the youth category had the same time (485) -- both runners are given 10th place:
```
 finish_time | category | place
-------------+----------+-------
         339 | adult    |     1
         366 | adult    |     2
         372 | adult    |     3
         384 | adult    |     4
         407 | adult    |     5
         412 | adult    |     6
         414 | adult    |     7
         447 | adult    |     8
         484 | adult    |     9
         315 | youth    |     1
         321 | youth    |     2
         362 | youth    |     3
         363 | youth    |     4
         366 | youth    |     5
         430 | youth    |     6
         449 | youth    |     7
         467 | youth    |     8
         476 | youth    |     9
         485 | youth    |    10
         485 | youth    |    10
```
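If the tie handling seems surprising, `rank()`'s semantics can be mimicked in a few lines of plain Python: tied values share a rank, and the rank after a tie skips ahead by the size of the tie. (This sketch assumes a single partition.)

```python
def sql_rank(values):
    """Mimic rank() over (order by value): 1-based, ties share a
    rank, and the next distinct value's rank skips past the tie."""
    ordered = sorted(values)
    ranks = []
    for i, value in enumerate(ordered):
        if i and ordered[i - 1] == value:
            ranks.append(ranks[-1])  # tie: reuse the previous rank
        else:
            ranks.append(i + 1)      # rank = number of preceding rows + 1
    return list(zip(ordered, ranks))

print(sql_rank([485, 315, 321, 485, 490]))
# [(315, 1), (321, 2), (485, 3), (485, 3), (490, 5)]
```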
What if we wanted to know how much each runner trailed the runner ahead of them? That is, we want the gaps between the finish times. For this we can use the `lag` function, which returns a value from a preceding row. We will then subtract the previous runner's time from the current time to get the difference:
```python
window = fn.lag(Times.finish_time).over(
    partition_by=[Times.category],
    order_by=[Times.finish_time])

query = (Times
         .select(
             Times.finish_time,
             Times.category,
             (Times.finish_time - window).alias('time_diff'))
         .order_by(Times.category, Times.finish_time))
```
This yields the following table. The first-place runner in each category has no value in the time_diff column, since nobody finished before them.
```
 finish_time | category | time_diff
-------------+----------+-----------
         339 | adult    |
         366 | adult    |        27
         372 | adult    |         6
         384 | adult    |        12
         407 | adult    |        23
         412 | adult    |         5
         414 | adult    |         2
         447 | adult    |        33
         484 | adult    |        37
         315 | youth    |
         321 | youth    |         6
         362 | youth    |        41
         363 | youth    |         1
         366 | youth    |         3
         430 | youth    |        64
         449 | youth    |        19
         467 | youth    |        18
         476 | youth    |         9
         485 | youth    |         9
         485 | youth    |         0
```
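For reference, here is a plain-Python sketch of what `lag()` over an ordered partition produces: each row paired with its distance from the previous row, with no value for the first row.

```python
def time_diffs(finish_times):
    """Mimic finish_time - lag(finish_time) over (order by finish_time):
    pair each time with its gap from the previous row (None for the first)."""
    ordered = sorted(finish_times)
    previous = [None] + ordered[:-1]
    return [(t, t - p if p is not None else None)
            for p, t in zip(previous, ordered)]

print(time_diffs([339, 366, 372, 384]))
# [(339, None), (366, 27), (372, 6), (384, 12)]
```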
There are all sorts of special window-specific aggregates in addition to
rank. For a comprehensive list, check out the Postgresql docs.
I'm still learning my way around, but so far I've had a fun time experimenting with this feature!
Somehow I've managed to scrape by the past few years without really needing to use CASE statements in my queries. Then a couple months ago I had a pretty tricky query I needed to run, and after struggling for a while, I found that by using a CASE statement I could simplify things considerably.
So the setup is: let's say you're Amazon and you are shipping a person's order from several different warehouses, in several different shipments. You want to know when the person's order has been fully shipped or otherwise reached some "final" state. We'll say this could happen when each package's status becomes either "shipped" or "undeliverable".
Here is our simple data model:
```python
class Order(Model):
    recipient = CharField()

class Package(Model):
    order = ForeignKeyField(Order)
    description = CharField()
    status = CharField()
```
The way I went about getting this list was to calculate the number of Packages associated with each order, then check that this count equaled the number of Packages in a terminal state. Here is the SQL:
```sql
SELECT t1."id", t1."recipient"
FROM "order" AS t1
INNER JOIN "package" AS t2 ON (t1."id" = t2."order_id")
GROUP BY t1."id", t1."recipient"
HAVING (
    COUNT(t2."id") = SUM(
        CASE WHEN (t2."status" IN ('shipped', 'undeliverable'))
        THEN 1 ELSE 0 END))
```
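The HAVING clause boils down to "every package is terminal": the row count must equal the number of rows for which the CASE yields 1. A plain-Python sketch of that check, using made-up sample data:

```python
TERMINAL = {"shipped", "undeliverable"}

# Hypothetical orders mapped to their packages' statuses.
orders = {
    1: ["shipped", "undeliverable"],   # every package terminal
    2: ["shipped", "processing"],      # one package still in flight
}

def fully_shipped(statuses):
    # COUNT(id) == SUM(CASE WHEN status IN (...) THEN 1 ELSE 0 END)
    return len(statuses) == sum(
        1 if status in TERMINAL else 0 for status in statuses)

shipped_ids = [oid for oid, statuses in orders.items()
               if fully_shipped(statuses)]
print(shipped_ids)  # [1]
```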
To express this using peewee, I wrote a little helper to generate the CASE statement:
```python
# Represents "status" IN ('shipped', 'undeliverable')
predicate = Package.status << ['shipped', 'undeliverable']

shipped = (Order
           .select()
           .join(Package)
           .group_by(Order)
           .having(
               fn.COUNT(Package.id) ==
               fn.SUM(case(None, [(predicate, 1)], 0))))
```
If you'd like to see more examples of the case helper in peewee, check out the case documentation.
And yes, this query could also be done with some double negation:
```python
# Packages for the given order that are not in the shipped or
# undeliverable state.
subquery = (Package
            .select()
            .where(
                (Package.order == Order.id) &
                ~(Package.status << ['shipped', 'undeliverable'])))

# Orders for which no packages exist which are *not* shipped
# or undeliverable. Also do an INNER join to ensure that at
# least one Package exists.
shipped = (Order
           .select()
           .join(Package)
           .where(~Clause(SQL('EXISTS'), subquery))
           .distinct())
```
The last feature I'll mention here is savepoints, which are supported by Postgresql, MySQL and Sqlite. Savepoints basically allow you to nest transactions. The peewee implementation uses a similar API to transactions. Here is how you might use savepoints to implement a get_or_create function:
```python
class User(Model):
    email = CharField(unique=True)

def get_or_create(email):
    try:
        with db.savepoint():
            return User.create(email=email)
    except IntegrityError:
        return User.get(User.email == email)
```
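Under the hood, a savepoint context manager like this issues SAVEPOINT / RELEASE / ROLLBACK TO statements. Here is a rough sketch of the same get_or_create pattern written directly against the sqlite3 standard-library module, to show what happens at the SQL level (the table and column names are illustrative, not peewee's):

```python
import sqlite3

db = sqlite3.connect(":memory:")
db.execute("CREATE TABLE user (email TEXT UNIQUE)")

def get_or_create(email):
    try:
        # Open a savepoint so a failed INSERT rolls back only this much.
        db.execute("SAVEPOINT get_or_create")
        db.execute("INSERT INTO user (email) VALUES (?)", (email,))
        db.execute("RELEASE get_or_create")
    except sqlite3.IntegrityError:
        # UNIQUE constraint failed: the row already exists.
        db.execute("ROLLBACK TO get_or_create")
        db.execute("RELEASE get_or_create")
    return db.execute(
        "SELECT rowid FROM user WHERE email = ?", (email,)).fetchone()[0]

first = get_or_create("foo@example.com")
again = get_or_create("foo@example.com")
print(first, again)  # the same row both times
```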
In the example above, the get_or_create function will create a new savepoint (so as not to disturb any existing transaction) and then attempt to insert a new row. If the database raises an IntegrityError, we know that the email is already in use, so we can catch the exception and return the existing row. Pretty handy!
This post was getting a bit long so I'm cutting it here, but I hope to write about the other new features I didn't get to cover (arrays, json fields, compound queries).
I'm always interested in ideas for useful peewee features, so please don't be shy if you think something's missing from the library. One request that has come up several times recently is connection pooling, so be on the lookout for that in the next couple of releases.
Thanks for taking the time to read this post, I hope you found it interesting.
Feel free to leave a comment below. If you have a specific question about peewee,
you can post it to the mailing list or ask in
#peewee on freenode.