diff --git a/scribe/db/interface.py b/scribe/db/interface.py index b4f430b..5045705 100644 --- a/scribe/db/interface.py +++ b/scribe/db/interface.py @@ -101,6 +101,9 @@ class PrefixRow(metaclass=PrefixRowType): handle_value(result[packed_keys[tuple(k_args)]]) for k_args in key_args ] + def stage_multi_put(self, items): + self._op_stack.multi_put([RevertablePut(self.pack_key(*k), self.pack_value(*v)) for k, v in items]) + def get_pending(self, *key_args, fill_cache=True, deserialize_value=True): packed_key = self.pack_key(*key_args) last_op = self._op_stack.get_last_op_for_key(packed_key) @@ -178,7 +181,7 @@ class BasePrefixDB: cf = self._db.get_column_family(prefix.value) self.column_families[prefix.value] = cf - self._op_stack = RevertableOpStack(self.get, unsafe_prefixes=unsafe_prefixes) + self._op_stack = RevertableOpStack(self.get, self.multi_get, unsafe_prefixes=unsafe_prefixes) self._max_undo_depth = max_undo_depth def unsafe_commit(self): @@ -259,6 +262,17 @@ class BasePrefixDB: cf = self.column_families[key[:1]] return self._db.get((cf, key), fill_cache=fill_cache) + def multi_get(self, keys: typing.List[bytes], fill_cache=True): + first_key = keys[0] + if not all(first_key[0] == key[0] for key in keys): + raise ValueError('cannot multi-delete across column families') + cf = self.column_families[first_key[:1]] + db_result = self._db.multi_get([(cf, k) for k in keys], fill_cache=fill_cache) + return list(db_result.values()) + + def multi_delete(self, items: typing.List[typing.Tuple[bytes, bytes]]): + self._op_stack.multi_delete([RevertableDelete(k, v) for k, v in items]) + def iterator(self, start: bytes, column_family: 'rocksdb.ColumnFamilyHandle' = None, iterate_lower_bound: bytes = None, iterate_upper_bound: bytes = None, reverse: bool = False, include_key: bool = True, include_value: bool = True, diff --git a/scribe/db/revertable.py b/scribe/db/revertable.py index 64e1d88..a982a97 100644 --- a/scribe/db/revertable.py +++ b/scribe/db/revertable.py @@ -2,7 +2,7 @@ import struct import logging from string import printable from collections import defaultdict -from typing import Tuple, Iterable, Callable, Optional +from typing import Tuple, Iterable, Callable, Optional, List from scribe.db.common import DB_PREFIXES _OP_STRUCT = struct.Struct('>BLL') @@ -82,7 +82,8 @@ class OpStackIntegrity(Exception): class RevertableOpStack: - def __init__(self, get_fn: Callable[[bytes], Optional[bytes]], unsafe_prefixes=None): + def __init__(self, get_fn: Callable[[bytes], Optional[bytes]], + multi_get_fn: Callable[[List[bytes]], Iterable[Optional[bytes]]], unsafe_prefixes=None): """ This represents a sequence of revertable puts and deletes to a key-value database that checks for integrity violations when applying the puts and deletes. The integrity checks assure that keys that do not exist @@ -95,6 +96,7 @@ class RevertableOpStack: :param unsafe_prefixes: optional set of prefixes to ignore integrity errors for, violations are still logged """ self._get = get_fn + self._multi_get = multi_get_fn self._items = defaultdict(list) self._unsafe_prefixes = unsafe_prefixes or set() @@ -133,6 +135,88 @@ class RevertableOpStack: raise err self._items[op.key].append(op) + def multi_put(self, ops: List[RevertablePut]): + """ + Apply a put or delete op, checking that it introduces no integrity errors + """ + + if not ops: + return + + need_put = [] + + if not all(op.is_put for op in ops): + raise ValueError(f"list must contain only puts") + if not len(set(map(lambda op: op.key, ops))) == len(ops): + raise ValueError(f"list must contain unique keys") + + for op in ops: + if self._items[op.key] and op.invert() == self._items[op.key][-1]: + self._items[op.key].pop() # if the new op is the inverse of the last op, we can safely null both + continue + elif self._items[op.key] and self._items[op.key][-1] == op: # duplicate of last op + continue # raise an error? + else: + need_put.append(op) + + for op, stored_val in zip(need_put, self._multi_get(list(map(lambda item: item.key, need_put)))): + has_stored_val = stored_val is not None + delete_stored_op = None if not has_stored_val else RevertableDelete(op.key, stored_val) + will_delete_existing_stored = False if delete_stored_op is None else (delete_stored_op in self._items[op.key]) + try: + if has_stored_val and not will_delete_existing_stored: + raise OpStackIntegrity(f"db op tries to overwrite before deleting existing: {op}") + except OpStackIntegrity as err: + if op.key[:1] in self._unsafe_prefixes: + log.debug(f"skipping over integrity error: {err}") + else: + raise err + self._items[op.key].append(op) + + def multi_delete(self, ops: List[RevertableDelete]): + """ + Apply a put or delete op, checking that it introduces no integrity errors + """ + + if not ops: + return + + need_delete = [] + + if not all(op.is_delete for op in ops): + raise ValueError(f"list must contain only deletes") + if not len(set(map(lambda op: op.key, ops))) == len(ops): + raise ValueError(f"list must contain unique keys") + + for op in ops: + if self._items[op.key] and op.invert() == self._items[op.key][-1]: + self._items[op.key].pop() # if the new op is the inverse of the last op, we can safely null both + continue + elif self._items[op.key] and self._items[op.key][-1] == op: # duplicate of last op + continue # raise an error? + else: + need_delete.append(op) + + for op, stored_val in zip(need_delete, self._multi_get(list(map(lambda item: item.key, need_delete)))): + has_stored_val = stored_val is not None + delete_stored_op = None if not has_stored_val else RevertableDelete(op.key, stored_val) + will_delete_existing_stored = False if delete_stored_op is None else (delete_stored_op in self._items[op.key]) + try: + if op.is_delete and has_stored_val and stored_val != op.value and not will_delete_existing_stored: + # there is a value and we're not deleting it in this op + # check that a delete for the stored value is in the stack + raise OpStackIntegrity(f"db op tries to delete with incorrect existing value {op}") + elif not stored_val: + raise OpStackIntegrity(f"db op tries to delete nonexistent key: {op}") + elif op.is_delete and stored_val != op.value: + raise OpStackIntegrity(f"db op tries to delete with incorrect value: {op}") + except OpStackIntegrity as err: + if op.key[:1] in self._unsafe_prefixes: + log.debug(f"skipping over integrity error: {err}") + else: + raise err + self._items[op.key].append(op) + def extend_ops(self, ops: Iterable[RevertableOp]): """ Apply a sequence of put or delete ops, checking that they introduce no integrity errors