added unit test for Access-Control HTTP headers

This commit is contained in:
Lex Berezhny 2021-04-06 15:19:34 -04:00
parent b97164fcfb
commit c8781392be
2 changed files with 44 additions and 2 deletions

View file

@ -542,14 +542,13 @@ class Daemon(metaclass=JSONRPCServerType):
async def add_cors_headers(self, request): async def add_cors_headers(self, request):
if self.conf.allowed_origin: if self.conf.allowed_origin:
response = web.Response( return web.Response(
headers={ headers={
'Access-Control-Allow-Origin': self.conf.allowed_origin, 'Access-Control-Allow-Origin': self.conf.allowed_origin,
'Access-Control-Allow-Methods': self.conf.allowed_origin, 'Access-Control-Allow-Methods': self.conf.allowed_origin,
'Access-Control-Allow-Headers': self.conf.allowed_origin, 'Access-Control-Allow-Headers': self.conf.allowed_origin,
} }
) )
return response
return None return None
async def handle_old_jsonrpc(self, request): async def handle_old_jsonrpc(self, request):

View file

@ -1,11 +1,19 @@
import unittest import unittest
from aiohttp import ClientSession
from aiohttp.test_utils import make_mocked_request as request from aiohttp.test_utils import make_mocked_request as request
from aiohttp.web import HTTPForbidden from aiohttp.web import HTTPForbidden
from lbry.testcase import AsyncioTestCase from lbry.testcase import AsyncioTestCase
from lbry.conf import Config from lbry.conf import Config
from lbry.extras.daemon.security import is_request_allowed as allowed, ensure_request_allowed as ensure from lbry.extras.daemon.security import is_request_allowed as allowed, ensure_request_allowed as ensure
from lbry.extras.daemon.components import (
DATABASE_COMPONENT, BLOB_COMPONENT, WALLET_COMPONENT, DHT_COMPONENT,
HASH_ANNOUNCER_COMPONENT, FILE_MANAGER_COMPONENT, PEER_PROTOCOL_SERVER_COMPONENT,
UPNP_COMPONENT, EXCHANGE_RATE_MANAGER_COMPONENT, WALLET_SERVER_PAYMENTS_COMPONENT,
LIBTORRENT_COMPONENT
)
from lbry.extras.daemon.daemon import Daemon
class TestAllowedOrigin(unittest.TestCase): class TestAllowedOrigin(unittest.TestCase):
@ -51,3 +59,38 @@ class TestAllowedOrigin(unittest.TestCase):
ensure(request('GET', '/', headers={'Origin': 'hackers.com'}), conf) ensure(request('GET', '/', headers={'Origin': 'hackers.com'}), conf)
self.assertIn("'hackers.com' are not allowed", log.output[0]) self.assertIn("'hackers.com' are not allowed", log.output[0])
self.assertIn("'allowed_origin' limits requests to: 'localhost'", log.output[0]) self.assertIn("'allowed_origin' limits requests to: 'localhost'", log.output[0])
class TestAccessHeaders(AsyncioTestCase):
async def asyncSetUp(self):
conf = Config(allowed_origin='localhost')
conf.data_dir = '/tmp'
conf.share_usage_data = False
conf.api = 'localhost:5299'
conf.components_to_skip = (
DATABASE_COMPONENT, BLOB_COMPONENT, WALLET_COMPONENT, DHT_COMPONENT,
HASH_ANNOUNCER_COMPONENT, FILE_MANAGER_COMPONENT, PEER_PROTOCOL_SERVER_COMPONENT,
UPNP_COMPONENT, EXCHANGE_RATE_MANAGER_COMPONENT, WALLET_SERVER_PAYMENTS_COMPONENT,
LIBTORRENT_COMPONENT
)
Daemon.component_attributes = {}
self.daemon = Daemon(conf)
await self.daemon.start()
self.addCleanup(self.daemon.stop)
async def test_headers(self):
async with ClientSession() as session:
# OPTIONS
async with session.options('http://localhost:5299') as resp:
self.assertEqual(resp.headers['Access-Control-Allow-Origin'], 'localhost')
self.assertEqual(resp.headers['Access-Control-Allow-Methods'], 'localhost')
self.assertEqual(resp.headers['Access-Control-Allow-Headers'], 'localhost')
# GET
status = {'method': 'status', 'params': []}
async with session.get('http://localhost:5299/lbryapi', json=status) as resp:
self.assertEqual(resp.headers['Access-Control-Allow-Origin'], 'localhost')
self.assertEqual(resp.headers['Access-Control-Allow-Methods'], 'localhost')
self.assertEqual(resp.headers['Access-Control-Allow-Headers'], 'localhost')