Merge lp:~ack/storm/postgres-case-expr into lp:storm

Proposed by Alberto Donato
Status: Merged
Merged at revision: 475
Proposed branch: lp:~ack/storm/postgres-case-expr
Merge into: lp:storm
Diff against target: 104 lines (+75/-1)
2 files modified
storm/databases/postgres.py (+35/-0)
tests/databases/postgres.py (+40/-1)
To merge this branch: bzr merge lp:~ack/storm/postgres-case-expr
Reviewer Review Type Date Requested Status
Free Ekanayaka (community) Approve
Review via email: mp+261176@code.launchpad.net

Description of the change

This adds support for the Postgres CASE expression.

To post a comment you must log in.
Revision history for this message
Free Ekanayaka (free.ekanayaka) wrote :

+1 with a small wording fix

review: Approve
lp:~ack/storm/postgres-case-expr updated
477. By Alberto Donato

Fix docstring (Free's review).

Preview Diff

[H/L] Next/Prev Comment, [J/K] Next/Prev File, [N/P] Next/Prev Hunk
1=== modified file 'storm/databases/postgres.py'
2--- storm/databases/postgres.py 2013-09-25 22:06:34 +0000
3+++ storm/databases/postgres.py 2015-06-05 07:28:10 +0000
4@@ -80,6 +80,41 @@
5 return "%s RETURNING %s" % (expr, columns)
6
7
8+class Case(Expr):
9+ """A CASE statement.
10+
11+ @param cases: a list of (condition, result) tuples, or of (value, result)
12+ tuples if an expression is also passed.
13+ @param expression: the expression to compare (if the simple form is used).
14+ @param default: an optional default result, used if no other case matches.
15+
16+ """
17+ def __init__(self, cases, expression=Undef, default=Undef):
18+ self.cases = cases
19+ self.expression = expression
20+ self.default = default
21+
22+
23+@compile.when(Case)
24+def compile_case(compile, expr, state):
25+ cases = [
26+ "WHEN %s THEN %s" % (
27+ compile(condition, state), compile(value, state))
28+ for condition, value in expr.cases]
29+
30+ if expr.expression is not Undef:
31+ expression = compile(expr.expression, state) + " "
32+ else:
33+ expression = ""
34+
35+ if expr.default is not Undef:
36+ default = " ELSE %s" % compile(expr.default, state)
37+ else:
38+ default = ""
39+
40+ return "CASE %s%s%s END" % (expression, " ".join(cases), default)
41+
42+
43 class currval(FuncExpr):
44
45 name = "currval"
46
47=== modified file 'tests/databases/postgres.py'
48--- tests/databases/postgres.py 2015-06-03 16:01:13 +0000
49+++ tests/databases/postgres.py 2015-06-05 07:28:10 +0000
50@@ -22,7 +22,8 @@
51 import os
52
53 from storm.databases.postgres import (
54- Postgres, compile, currval, Returning, PostgresTimeoutTracer, make_dsn)
55+ Postgres, compile, currval, Returning, Case, PostgresTimeoutTracer,
56+ make_dsn)
57 from storm.database import create_database
58 from storm.exceptions import InterfaceError, ProgrammingError
59 from storm.variables import DateTimeVariable, RawStrVariable
60@@ -347,6 +348,44 @@
61 'SELECT "my schema"."my table"."my.column" '
62 'FROM "my schema"."my table"')
63
64+ def test_compile_case(self):
65+ """The Case expr is compiled into a Postgres CASE expression."""
66+ cases = [
67+ (Column("foo") > 3, u"big"), (Column("bar") == None, 4)]
68+ state = State()
69+ statement = compile(Case(cases), state)
70+ self.assertEqual(
71+ "CASE WHEN (foo > ?) THEN ? WHEN (bar IS NULL) THEN ? END",
72+ statement)
73+ self.assertEqual(
74+ [3, "big", 4], [param.get() for param in state.parameters])
75+
76+ def test_compile_case_with_default(self):
77+ """
78+ If a default is provided, the resulting CASE expression includes
79+ an ELSE clause.
80+ """
81+ cases = [(Column("foo") > 3, u"big")]
82+ state = State()
83+ statement = compile(Case(cases, default=9), state)
84+ self.assertEqual(
85+ "CASE WHEN (foo > ?) THEN ? ELSE ? END", statement)
86+ self.assertEqual(
87+ [3, "big", 9], [param.get() for param in state.parameters])
88+
89+ def test_compile_case_with_expression(self):
90+ """
91+ If an expression is provided, the resulting CASE expression uses the
92+ simple syntax.
93+ """
94+ cases = [(1, u"one"), (2, u"two")]
95+ state = State()
96+ statement = compile(Case(cases, expression=Column("foo")), state)
97+ self.assertEqual(
98+ "CASE foo WHEN ? THEN ? WHEN ? THEN ? END", statement)
99+ self.assertEqual(
100+ [1, "one", 2, "two"], [param.get() for param in state.parameters])
101+
102 def test_currval_no_escaping(self):
103 expr = currval(Column("thecolumn", "theschema.thetable"))
104 statement = compile(expr)

Subscribers

People subscribed via source and target branches

to status/vote changes: