diff --git a/scribe/db/db.py b/scribe/db/db.py index cfa9819..6323f79 100644 --- a/scribe/db/db.py +++ b/scribe/db/db.py @@ -327,7 +327,7 @@ class HubDB: if blocker_hash: reason_row = self._fs_get_claim_by_hash(blocker_hash) return ExpandedResolveResult( - None, ResolveCensoredError(url, blocker_hash, censor_row=reason_row), None, None + None, ResolveCensoredError(url, blocker_hash.hex(), censor_row=reason_row), None, None ) if claim.reposted_claim_hash: repost = self._fs_get_claim_by_hash(claim.reposted_claim_hash) diff --git a/scribe/elasticsearch/search.py b/scribe/elasticsearch/search.py index 3a95445..c038756 100644 --- a/scribe/elasticsearch/search.py +++ b/scribe/elasticsearch/search.py @@ -10,7 +10,7 @@ from typing import Optional, List, Iterable from elasticsearch import AsyncElasticsearch, NotFoundError, ConnectionError from elasticsearch.helpers import async_streaming_bulk -from scribe.schema.result import Outputs, Censor +from scribe.schema.result import Censor, Outputs from scribe.schema.tags import clean_tags from scribe.schema.url import normalize_name from scribe.error import TooManyClaimSearchParametersError @@ -285,18 +285,21 @@ class SearchIndex: async with cache_item.lock: if cache_item.result: return cache_item.result - censor = Censor(Censor.SEARCH) response, offset, total = await self.search(**kwargs) - censor.apply(response) + censored = {} + for row in response: + if (row.get('censor_type') or 0) >= Censor.SEARCH: + censoring_channel_hash = bytes.fromhex(row['censoring_channel_id'])[::-1] + censored.setdefault(censoring_channel_hash, set()) + censored[censoring_channel_hash].add(row['tx_hash']) total_referenced.extend(response) - - if censor.censored: + if censored: response, _, _ = await self.search(**kwargs, censor_type=Censor.NOT_CENSORED) total_referenced.extend(response) response = [self._make_resolve_result(r) for r in response] extra = [self._make_resolve_result(r) for r in await self._get_referenced_rows(total_referenced)] result = Outputs.to_base64( - response, extra, offset, total, censor + response, extra, offset, total, censored ) cache_item.result = result return result @@ -314,7 +317,6 @@ class SearchIndex: for result in expand_result(filter(lambda doc: doc['found'], results["docs"])): self.claim_cache.set(result['claim_id'], result) - async def search(self, **kwargs): try: return await self.search_ahead(**kwargs) diff --git a/scribe/schema/result.py b/scribe/schema/result.py index e51c86d..2429e93 100644 --- a/scribe/schema/result.py +++ b/scribe/schema/result.py @@ -1,5 +1,5 @@ import base64 -from typing import List, TYPE_CHECKING, Union, Optional, NamedTuple +from typing import List, TYPE_CHECKING, Union, Optional, Dict, Set, Tuple from itertools import chain from scribe.error import ResolveCensoredError @@ -12,53 +12,46 @@ NOT_FOUND = ErrorMessage.Code.Name(ErrorMessage.NOT_FOUND) BLOCKED = ErrorMessage.Code.Name(ErrorMessage.BLOCKED) -def set_reference(reference, claim_hash, rows): - if claim_hash: - for txo in rows: - if claim_hash == txo.claim_hash: - reference.tx_hash = txo.tx_hash - reference.nout = txo.position - reference.height = txo.height - return - - class Censor: - NOT_CENSORED = 0 SEARCH = 1 RESOLVE = 2 - __slots__ = 'censor_type', 'censored' - def __init__(self, censor_type): - self.censor_type = censor_type - self.censored = {} +def encode_txo(txo_message: OutputsMessage, resolve_result: Union['ResolveResult', Exception]): + if isinstance(resolve_result, Exception): + txo_message.error.text = resolve_result.args[0] + if isinstance(resolve_result, ValueError): + txo_message.error.code = ErrorMessage.INVALID + elif isinstance(resolve_result, LookupError): + txo_message.error.code = ErrorMessage.NOT_FOUND + return + txo_message.tx_hash = resolve_result.tx_hash + txo_message.nout = resolve_result.position + txo_message.height = resolve_result.height + txo_message.claim.short_url = resolve_result.short_url + txo_message.claim.reposted = resolve_result.reposted + txo_message.claim.is_controlling = resolve_result.is_controlling + txo_message.claim.creation_height = resolve_result.creation_height + txo_message.claim.activation_height = resolve_result.activation_height + txo_message.claim.expiration_height = resolve_result.expiration_height + txo_message.claim.effective_amount = resolve_result.effective_amount + txo_message.claim.support_amount = resolve_result.support_amount - def is_censored(self, row): - return (row.get('censor_type') or self.NOT_CENSORED) >= self.censor_type - - def apply(self, rows): - return [row for row in rows if not self.censor(row)] - - def censor(self, row) -> Optional[bytes]: - if self.is_censored(row): - censoring_channel_hash = bytes.fromhex(row['censoring_channel_id'])[::-1] - self.censored.setdefault(censoring_channel_hash, set()) - self.censored[censoring_channel_hash].add(row['tx_hash']) - return censoring_channel_hash - return None - - def to_message(self, outputs: OutputsMessage, extra_txo_rows: List['ResolveResult']): - for censoring_channel_hash, count in self.censored.items(): - outputs.blocked_total += len(count) - blocked = outputs.blocked.add() - blocked.count = len(count) - for resolve_result in extra_txo_rows: - if resolve_result.claim_hash == censoring_channel_hash: - blocked.channel.tx_hash = resolve_result.tx_hash - blocked.channel.nout = resolve_result.position - blocked.channel.height = resolve_result.height - return + if resolve_result.canonical_url is not None: + txo_message.claim.canonical_url = resolve_result.canonical_url + if resolve_result.last_takeover_height is not None: + txo_message.claim.take_over_height = resolve_result.last_takeover_height + if resolve_result.claims_in_channel is not None: + txo_message.claim.claims_in_channel = resolve_result.claims_in_channel + if resolve_result.reposted_claim_hash and resolve_result.reposted_tx_hash is not None: + txo_message.claim.repost.tx_hash = resolve_result.reposted_tx_hash + txo_message.claim.repost.nout = resolve_result.reposted_tx_position + txo_message.claim.repost.height = resolve_result.reposted_height + if resolve_result.channel_hash and resolve_result.channel_tx_hash is not None: + txo_message.claim.channel.tx_hash = resolve_result.channel_tx_hash + txo_message.claim.channel.nout = resolve_result.channel_tx_position + txo_message.claim.channel.height = resolve_result.channel_height class Outputs: @@ -170,67 +163,45 @@ class Outputs: outputs.blocked, outputs.blocked_total ) - @classmethod - def to_base64(cls, txo_rows, extra_txo_rows, offset=0, total=None, blocked: Censor = None) -> str: - return base64.b64encode(cls.to_bytes(txo_rows, extra_txo_rows, offset, total, blocked)).decode() + @staticmethod + def to_base64(txo_rows: List[Union[Exception, 'ResolveResult']], extra_txo_rows: List['ResolveResult'], + offset: int = 0, total: Optional[int] = None, + censored: Optional[Dict[bytes, Set[bytes]]] = None) -> str: + return base64.b64encode(Outputs.to_bytes(txo_rows, extra_txo_rows, offset, total, censored)).decode() - @classmethod - def to_bytes(cls, txo_rows, extra_txo_rows, offset=0, total=None, blocked: Censor = None) -> bytes: + @staticmethod + def to_bytes(txo_rows: List[Union[Exception, 'ResolveResult']], extra_txo_rows: List['ResolveResult'], + offset: int = 0, total: Optional[int] = None, + censored: Optional[Dict[bytes, Set[bytes]]] = None) -> bytes: page = OutputsMessage() page.offset = offset if total is not None: page.total = total - if blocked is not None: - blocked.to_message(page, extra_txo_rows) - for row in extra_txo_rows: - cls.encode_txo(page.extra_txos.add(), row) + censored = censored or {} + censored_txos: Dict[bytes, List[Tuple[str, 'ResolveResult']]] = {} for row in txo_rows: txo_message = page.txos.add() if isinstance(row, ResolveCensoredError): - for resolve_result in extra_txo_rows: - if resolve_result.claim_hash == row.censor_id: - txo_message.error.code = ErrorMessage.BLOCKED - txo_message.error.text = str(row) - txo_message.error.blocked.channel.tx_hash = resolve_result.tx_hash - txo_message.error.blocked.channel.nout = resolve_result.position - txo_message.error.blocked.channel.height = resolve_result.height - break + censored_hash = bytes.fromhex(row.censor_id) + if censored_hash not in censored_txos: + censored_txos[censored_hash] = [] + censored_txos[censored_hash].append((str(row), txo_message)) else: - cls.encode_txo(txo_message, row) + encode_txo(txo_message, row) + for row in extra_txo_rows: + if row.claim_hash in censored: + page.blocked_total += len(censored[row.claim_hash]) + blocked = page.blocked.add() + blocked.count = len(censored[row.claim_hash]) + blocked.channel.tx_hash = row.tx_hash + blocked.channel.nout = row.position + blocked.channel.height = row.height + if row.claim_hash in censored_txos: + for (text, txo_message) in censored_txos[row.claim_hash]: + txo_message.error.code = ErrorMessage.BLOCKED + txo_message.error.text = text + txo_message.error.blocked.channel.tx_hash = row.tx_hash + txo_message.error.blocked.channel.nout = row.position + txo_message.error.blocked.channel.height = row.height + encode_txo(page.extra_txos.add(), row) return page.SerializeToString() - - @classmethod - def encode_txo(cls, txo_message: OutputsMessage, resolve_result: Union['ResolveResult', Exception]): - if isinstance(resolve_result, Exception): - txo_message.error.text = resolve_result.args[0] - if isinstance(resolve_result, ValueError): - txo_message.error.code = ErrorMessage.INVALID - elif isinstance(resolve_result, LookupError): - txo_message.error.code = ErrorMessage.NOT_FOUND - return - txo_message.tx_hash = resolve_result.tx_hash - txo_message.nout = resolve_result.position - txo_message.height = resolve_result.height - txo_message.claim.short_url = resolve_result.short_url - txo_message.claim.reposted = resolve_result.reposted - txo_message.claim.is_controlling = resolve_result.is_controlling - txo_message.claim.creation_height = resolve_result.creation_height - txo_message.claim.activation_height = resolve_result.activation_height - txo_message.claim.expiration_height = resolve_result.expiration_height - txo_message.claim.effective_amount = resolve_result.effective_amount - txo_message.claim.support_amount = resolve_result.support_amount - - if resolve_result.canonical_url is not None: - txo_message.claim.canonical_url = resolve_result.canonical_url - if resolve_result.last_takeover_height is not None: - txo_message.claim.take_over_height = resolve_result.last_takeover_height - if resolve_result.claims_in_channel is not None: - txo_message.claim.claims_in_channel = resolve_result.claims_in_channel - if resolve_result.reposted_claim_hash and resolve_result.reposted_tx_hash is not None: - txo_message.claim.repost.tx_hash = resolve_result.reposted_tx_hash - txo_message.claim.repost.nout = resolve_result.reposted_tx_position - txo_message.claim.repost.height = resolve_result.reposted_height - if resolve_result.channel_hash and resolve_result.channel_tx_hash is not None: - txo_message.claim.channel.tx_hash = resolve_result.channel_tx_hash - txo_message.claim.channel.nout = resolve_result.channel_tx_position - txo_message.claim.channel.height = resolve_result.channel_height