From 71551058cefd48e405353fbfd33685e9c6934675 Mon Sep 17 00:00:00 2001 From: Lex Berezhny Date: Wed, 17 Jul 2019 21:46:38 -0400 Subject: [PATCH] query interpolation --- .../tests/client_tests/unit/test_database.py | 23 +++++++++++++++++-- torba/torba/client/basedatabase.py | 14 +++++++++++ 2 files changed, 35 insertions(+), 2 deletions(-) diff --git a/torba/tests/client_tests/unit/test_database.py b/torba/tests/client_tests/unit/test_database.py index 240933cc0..d5d45576d 100644 --- a/torba/tests/client_tests/unit/test_database.py +++ b/torba/tests/client_tests/unit/test_database.py @@ -1,11 +1,10 @@ import unittest import sqlite3 -from functools import wraps from torba.client.wallet import Wallet from torba.client.constants import COIN from torba.coin.bitcoinsegwit import MainNetLedger as ledger_class -from torba.client.basedatabase import query, constraints_to_sql, AIOSQLite +from torba.client.basedatabase import query, interpolate, constraints_to_sql, AIOSQLite from torba.client.hash import sha256 from torba.testcase import AsyncioTestCase @@ -167,6 +166,26 @@ class TestQueryBuilder(unittest.TestCase): ("select * from foo LIMIT 20 OFFSET 10", {}) ) + def test_query_interpolation(self): + self.maxDiff = None + # tests that interpolation replaces longer keys first + self.assertEqual( + interpolate(*query( + "select * from foo", + a__not='b', b__in='select * from blah where c=:$c', + d__any={'one__like': 'o', 'two': 2}, + a0=3, a00=1, a00a=2, a00aa=4, # <-- breaks without correct interpolation key order + ahash=memoryview(sha256(b'hello world')), + limit=10, order_by='b', **{'$c': 3}) + ), + "select * from foo WHERE a != 'b' AND " + "b IN (select * from blah where c=3) AND " + "(one LIKE 'o' OR two = 2) AND " + "a0 = 3 AND a00 = 1 AND a00a = 2 AND a00aa = 4 " + "AND ahash = e9cdefe2acf78890ee80537ae3ef84c4faab7ddad7522ea5083e4d93b9274db9 " + "ORDER BY b LIMIT 10", + ) + class TestQueries(AsyncioTestCase): diff --git a/torba/torba/client/basedatabase.py b/torba/torba/client/basedatabase.py index a431291f9..673d737f1 100644 --- a/torba/torba/client/basedatabase.py +++ b/torba/torba/client/basedatabase.py @@ -1,5 +1,6 @@ import logging import asyncio +from binascii import hexlify from asyncio import wrap_future from concurrent.futures.thread import ThreadPoolExecutor @@ -191,6 +192,19 @@ def query(select, **constraints): return ' '.join(sql), values +def interpolate(sql, values): + for k in sorted(values.keys(), reverse=True): + value = values[k] + if isinstance(value, memoryview): + value = hexlify(bytes(value)[::-1]).decode() + elif isinstance(value, str): + value = f"'{value}'" + else: + value = str(value) + sql = sql.replace(f":{k}", value) + return sql + + def rows_to_dict(rows, fields): if rows: return [dict(zip(fields, r)) for r in rows]