diff --git a/electrum/daemon.py b/electrum/daemon.py index 7d26c8873..e2a227476 100644 --- a/electrum/daemon.py +++ b/electrum/daemon.py @@ -29,16 +29,16 @@ import time import traceback import sys import threading -from typing import Dict, Optional, Tuple, Iterable +from typing import Dict, Optional, Tuple, Iterable, Callable, Union, Sequence, Mapping from base64 import b64decode, b64encode from collections import defaultdict import concurrent from concurrent import futures +import json import aiohttp from aiohttp import web, client_exceptions from aiorpcx import TaskGroup -import json from . import util from .network import Network @@ -151,6 +151,11 @@ class AuthenticatedServer(Logger): self.rpc_user = rpc_user self.rpc_password = rpc_password self.auth_lock = asyncio.Lock() + self._methods = {} # type: Dict[str, Callable] + + def register_method(self, f): + assert f.__name__ not in self._methods, f"name collision for {f.__name__}" + self._methods[f.__name__] = f async def authenticate(self, headers): if self.rpc_password == '': @@ -184,15 +189,21 @@ class AuthenticatedServer(Logger): request = json.loads(request) method = request['method'] _id = request['id'] - params = request.get('params', []) - f = getattr(self, method) - assert f in self.methods - except: + params = request.get('params', []) # type: Union[Sequence, Mapping] + if method not in self._methods: + raise Exception(f"attempting to use unregistered method: {method}") + f = self._methods[method] + except Exception as e: + self.logger.exception("invalid request") return web.Response(text='Invalid Request', status=500) - response = {'id':_id} + response = {'id': _id} try: - response['result'] = await f(*params) + if isinstance(params, dict): + response['result'] = await f(**params) + else: + response['result'] = await f(*params) except BaseException as e: + self.logger.exception("internal error while executing RPC") response['error'] = str(e) return web.json_response(response) @@ -209,13 +220,12 @@ class CommandsServer(AuthenticatedServer): self.port = self.config.get('rpcport', 0) self.app = web.Application() self.app.router.add_post("/", self.handle) - self.methods = set() - self.methods.add(self.ping) - self.methods.add(self.gui) + self.register_method(self.ping) + self.register_method(self.gui) self.cmd_runner = Commands(config=self.config, network=self.daemon.network, daemon=self.daemon) for cmdname in known_commands: - self.methods.add(getattr(self.cmd_runner, cmdname)) - self.methods.add(self.run_cmdline) + self.register_method(getattr(self.cmd_runner, cmdname)) + self.register_method(self.run_cmdline) async def run(self): self.runner = web.AppRunner(self.app) @@ -277,9 +287,8 @@ class WatchTowerServer(AuthenticatedServer): self.lnwatcher = network.local_watchtower self.app = web.Application() self.app.router.add_post("/", self.handle) - self.methods = set() - self.methods.add(self.get_ctn) - self.methods.add(self.add_sweep_tx) + self.register_method(self.get_ctn) + self.register_method(self.add_sweep_tx) async def run(self): self.runner = web.AppRunner(self.app)