Merge branch 'fix-dht-rpc-id'

This commit is contained in:
Jack Robison 2017-10-24 20:48:11 -04:00
commit 988e54602e
No known key found for this signature in database
GPG key ID: 284699E7404E3CFF
8 changed files with 62 additions and 48 deletions

View file

@ -19,6 +19,7 @@ at anytime.
* Fixed several parsing bugs that prevented replacing dead dht contacts * Fixed several parsing bugs that prevented replacing dead dht contacts
* Fixed lbryid length validation * Fixed lbryid length validation
* Fixed an old print statement that polluted logs * Fixed an old print statement that polluted logs
* Fixed rpc id length for dht requests
### Deprecated ### Deprecated
* *

View file

@ -59,3 +59,5 @@ from lbrynet.core.cryptoutils import get_lbry_hash_obj
h = get_lbry_hash_obj() h = get_lbry_hash_obj()
key_bits = h.digest_size * 8 # 384 bits key_bits = h.digest_size * 8 # 384 bits
rpc_id_length = 20

View file

@ -8,12 +8,17 @@
# may be created by processing this file with epydoc: http://epydoc.sf.net # may be created by processing this file with epydoc: http://epydoc.sf.net
from lbrynet.core.utils import generate_id from lbrynet.core.utils import generate_id
import constants
class Message(object): class Message(object):
""" Base class for messages - all "unknown" messages use this class """ """ Base class for messages - all "unknown" messages use this class """
def __init__(self, rpcID, nodeID): def __init__(self, rpcID, nodeID):
if len(rpcID) != constants.rpc_id_length:
raise ValueError("invalid rpc id: %i bytes (expected 20)" % len(rpcID))
if len(nodeID) != constants.key_bits / 8:
raise ValueError("invalid node id: %i bytes (expected 48)" % len(nodeID))
self.id = rpcID self.id = rpcID
self.nodeID = nodeID self.nodeID = nodeID
@ -23,7 +28,7 @@ class RequestMessage(Message):
def __init__(self, nodeID, method, methodArgs, rpcID=None): def __init__(self, nodeID, method, methodArgs, rpcID=None):
if rpcID is None: if rpcID is None:
rpcID = generate_id() rpcID = generate_id()[:constants.rpc_id_length]
Message.__init__(self, rpcID, nodeID) Message.__init__(self, rpcID, nodeID)
self.request = method self.request = method
self.args = methodArgs self.args = methodArgs

View file

@ -205,14 +205,14 @@ class KademliaProtocol(protocol.DatagramProtocol):
return return
try: try:
msgPrimitive = self._encoder.decode(datagram) msgPrimitive = self._encoder.decode(datagram)
except encoding.DecodeError: message = self._translator.fromPrimitive(msgPrimitive)
except (encoding.DecodeError, ValueError):
# We received some rubbish here # We received some rubbish here
return return
except IndexError: except IndexError:
log.warning("Couldn't decode dht datagram from %s", address) log.warning("Couldn't decode dht datagram from %s", address)
return return
message = self._translator.fromPrimitive(msgPrimitive)
remoteContact = Contact(message.nodeID, address[0], address[1], self) remoteContact = Contact(message.nodeID, address[0], address[1], self)
now = time.time() now = time.time()
@ -422,8 +422,10 @@ class KademliaProtocol(protocol.DatagramProtocol):
self._sentMessages[messageID] = (remoteContactID, df, timeoutCall, method, args) self._sentMessages[messageID] = (remoteContactID, df, timeoutCall, method, args)
else: else:
# No progress has been made # No progress has been made
del self._partialMessagesProgress[messageID] if messageID in self._partialMessagesProgress:
del self._partialMessages[messageID] del self._partialMessagesProgress[messageID]
if messageID in self._partialMessages:
del self._partialMessages[messageID]
df.errback(TimeoutError(remoteContactID)) df.errback(TimeoutError(remoteContactID))
def _hasProgressBeenMade(self, messageID): def _hasProgressBeenMade(self, messageID):

View file

@ -62,6 +62,7 @@ class NodeDataTest(unittest.TestCase):
self.cases.append((h.digest(), 5000+2*i)) self.cases.append((h.digest(), 5000+2*i))
self.cases.append((h.digest(), 5001+2*i)) self.cases.append((h.digest(), 5001+2*i))
@defer.inlineCallbacks
def testStore(self): def testStore(self):
""" Tests if the node can store (and privately retrieve) some data """ """ Tests if the node can store (and privately retrieve) some data """
for key, value in self.cases: for key, value in self.cases:
@ -70,7 +71,7 @@ class NodeDataTest(unittest.TestCase):
'lbryid': self.contact.id, 'lbryid': self.contact.id,
'token': self.token 'token': self.token
} }
self.node.store(key, request, self.contact.id, _rpcNodeContact=self.contact) yield self.node.store(key, request, self.contact.id, _rpcNodeContact=self.contact)
for key, value in self.cases: for key, value in self.cases:
expected_result = self.contact.compact_ip() + str(struct.pack('>H', value)) + \ expected_result = self.contact.compact_ip() + str(struct.pack('>H', value)) + \
self.contact.id self.contact.id
@ -90,7 +91,7 @@ class NodeContactTest(unittest.TestCase):
""" Tests if a contact can be added and retrieved correctly """ """ Tests if a contact can be added and retrieved correctly """
import lbrynet.dht.contact import lbrynet.dht.contact
# Create the contact # Create the contact
h = hashlib.sha1() h = hashlib.sha384()
h.update('node1') h.update('node1')
contactID = h.digest() contactID = h.digest()
contact = lbrynet.dht.contact.Contact(contactID, '127.0.0.1', 91824, self.node._protocol) contact = lbrynet.dht.contact.Contact(contactID, '127.0.0.1', 91824, self.node._protocol)
@ -133,6 +134,10 @@ class FakeRPCProtocol(protocol.DatagramProtocol):
def sendRPC(self, contact, method, args, rawResponse=False): def sendRPC(self, contact, method, args, rawResponse=False):
""" Fake RPC protocol; allows entangled.kademlia.contact.Contact objects to "send" RPCs""" """ Fake RPC protocol; allows entangled.kademlia.contact.Contact objects to "send" RPCs"""
h = hashlib.sha384()
h.update('rpcId')
rpc_id = h.digest()[:20]
if method == "findNode": if method == "findNode":
# get the specific contacts closest contacts # get the specific contacts closest contacts
closestContacts = [] closestContacts = []
@ -144,7 +149,8 @@ class FakeRPCProtocol(protocol.DatagramProtocol):
# Pack the closest contacts into a ResponseMessage # Pack the closest contacts into a ResponseMessage
for closeContact in closestContactsList: for closeContact in closestContactsList:
closestContacts.append((closeContact.id, closeContact.address, closeContact.port)) closestContacts.append((closeContact.id, closeContact.address, closeContact.port))
message = ResponseMessage("rpcId", contact.id, closestContacts)
message = ResponseMessage(rpc_id, contact.id, closestContacts)
df = defer.Deferred() df = defer.Deferred()
df.callback((message, (contact.address, contact.port))) df.callback((message, (contact.address, contact.port)))
return df return df
@ -171,7 +177,7 @@ class FakeRPCProtocol(protocol.DatagramProtocol):
response = closestContacts response = closestContacts
# Create the response message # Create the response message
message = ResponseMessage("rpcId", contact.id, response) message = ResponseMessage(rpc_id, contact.id, response)
df = defer.Deferred() df = defer.Deferred()
df.callback((message, (contact.address, contact.port))) df.callback((message, (contact.address, contact.port)))
return df return df
@ -189,7 +195,10 @@ class NodeLookupTest(unittest.TestCase):
# Note: The reactor is never started for this test. All deferred calls run sequentially, # Note: The reactor is never started for this test. All deferred calls run sequentially,
# since there is no asynchronous network communication # since there is no asynchronous network communication
# create the node to be tested in isolation # create the node to be tested in isolation
self.node = lbrynet.dht.node.Node('12345678901234567800', 4000, None, None, self._protocol) h = hashlib.sha384()
h.update('node1')
node_id = str(h.digest())
self.node = lbrynet.dht.node.Node(node_id, 4000, None, None, self._protocol)
self.updPort = 81173 self.updPort = 81173
self.contactsAmount = 80 self.contactsAmount = 80
# Reinitialise the routing table # Reinitialise the routing table
@ -198,16 +207,16 @@ class NodeLookupTest(unittest.TestCase):
# create 160 bit node ID's for test purposes # create 160 bit node ID's for test purposes
self.testNodeIDs = [] self.testNodeIDs = []
idNum = int(self.node.node_id) idNum = int(self.node.node_id.encode('hex'), 16)
for i in range(self.contactsAmount): for i in range(self.contactsAmount):
# create the testNodeIDs in ascending order, away from the actual node ID, # create the testNodeIDs in ascending order, away from the actual node ID,
# with regards to the distance metric # with regards to the distance metric
self.testNodeIDs.append(idNum + i + 1) self.testNodeIDs.append(str("%X" % (idNum + i + 1)).decode('hex'))
# generate contacts # generate contacts
self.contacts = [] self.contacts = []
for i in range(self.contactsAmount): for i in range(self.contactsAmount):
contact = lbrynet.dht.contact.Contact(str(self.testNodeIDs[i]), "127.0.0.1", contact = lbrynet.dht.contact.Contact(self.testNodeIDs[i], "127.0.0.1",
self.updPort + i + 1, self._protocol) self.updPort + i + 1, self._protocol)
self.contacts.append(contact) self.contacts.append(contact)
@ -241,23 +250,24 @@ class NodeLookupTest(unittest.TestCase):
lbrynet.dht.datastore.DictDataStore())) lbrynet.dht.datastore.DictDataStore()))
self._protocol.createNetwork(contacts_with_datastores) self._protocol.createNetwork(contacts_with_datastores)
@defer.inlineCallbacks
def testNodeBootStrap(self): def testNodeBootStrap(self):
""" Test bootstrap with the closest possible contacts """ """ Test bootstrap with the closest possible contacts """
df = self.node._iterativeFind(self.node.node_id, self.contacts[0:8]) activeContacts = yield self.node._iterativeFind(self.node.node_id, self.contacts[0:8])
# Set the expected result # Set the expected result
expectedResult = [] expectedResult = set()
for item in self.contacts[0:6]: for item in self.contacts[0:6]:
expectedResult.append(item.id) expectedResult.add(item.id)
# Get the result from the deferred # Get the result from the deferred
activeContacts = df.result
# Check the length of the active contacts # Check the length of the active contacts
self.failUnlessEqual(activeContacts.__len__(), expectedResult.__len__(), self.failUnlessEqual(activeContacts.__len__(), expectedResult.__len__(),
"More active contacts should exist, there should be %d " "More active contacts should exist, there should be %d "
"contacts" % expectedResult.__len__()) "contacts but there are %d" % (len(expectedResult),
len(activeContacts)))
# Check that the received active contacts are the same as the input contacts # Check that the received active contacts are the same as the input contacts
self.failUnlessEqual(activeContacts, expectedResult, self.failUnlessEqual({contact.id for contact in activeContacts}, expectedResult,
"Active should only contain the closest possible contacts" "Active should only contain the closest possible contacts"
" which were used as input for the boostrap") " which were used as input for the boostrap")

View file

@ -16,7 +16,7 @@ class KademliaProtocolTest(unittest.TestCase):
def setUp(self): def setUp(self):
del lbrynet.dht.protocol.reactor del lbrynet.dht.protocol.reactor
lbrynet.dht.protocol.reactor = twisted.internet.selectreactor.SelectReactor() lbrynet.dht.protocol.reactor = twisted.internet.selectreactor.SelectReactor()
self.node = Node(node_id='node1', udpPort=9182, externalIP="127.0.0.1") self.node = Node(node_id='1' * 48, udpPort=9182, externalIP="127.0.0.1")
self.protocol = lbrynet.dht.protocol.KademliaProtocol(self.node) self.protocol = lbrynet.dht.protocol.KademliaProtocol(self.node)
def testReactor(self): def testReactor(self):
@ -39,7 +39,7 @@ class KademliaProtocolTest(unittest.TestCase):
lbrynet.dht.constants.rpcAttempts = 1 lbrynet.dht.constants.rpcAttempts = 1
lbrynet.dht.constants.rpcTimeout = 1 lbrynet.dht.constants.rpcTimeout = 1
self.node.ping = fake_ping self.node.ping = fake_ping
deadContact = lbrynet.dht.contact.Contact('node2', '127.0.0.1', 9182, self.protocol) deadContact = lbrynet.dht.contact.Contact('2' * 48, '127.0.0.1', 9182, self.protocol)
self.node.addContact(deadContact) self.node.addContact(deadContact)
# Make sure the contact was added # Make sure the contact was added
self.failIf(deadContact not in self.node.contacts, self.failIf(deadContact not in self.node.contacts,
@ -73,7 +73,7 @@ class KademliaProtocolTest(unittest.TestCase):
def testRPCRequest(self): def testRPCRequest(self):
""" Tests if a valid RPC request is executed and responded to correctly """ """ Tests if a valid RPC request is executed and responded to correctly """
remoteContact = lbrynet.dht.contact.Contact('node2', '127.0.0.1', 9182, self.protocol) remoteContact = lbrynet.dht.contact.Contact('2' * 48, '127.0.0.1', 9182, self.protocol)
self.node.addContact(remoteContact) self.node.addContact(remoteContact)
self.error = None self.error = None
@ -105,7 +105,7 @@ class KademliaProtocolTest(unittest.TestCase):
Verifies that a RPC request for an existing but unpublished Verifies that a RPC request for an existing but unpublished
method is denied, and that the associated (remote) exception gets method is denied, and that the associated (remote) exception gets
raised locally """ raised locally """
remoteContact = lbrynet.dht.contact.Contact('node2', '127.0.0.1', 9182, self.protocol) remoteContact = lbrynet.dht.contact.Contact('2' * 48, '127.0.0.1', 9182, self.protocol)
self.node.addContact(remoteContact) self.node.addContact(remoteContact)
self.error = None self.error = None
@ -126,7 +126,7 @@ class KademliaProtocolTest(unittest.TestCase):
# Publish the "local" node on the network # Publish the "local" node on the network
lbrynet.dht.protocol.reactor.listenUDP(9182, self.protocol) lbrynet.dht.protocol.reactor.listenUDP(9182, self.protocol)
# Simulate the RPC # Simulate the RPC
df = remoteContact.pingNoRPC() df = remoteContact.not_a_rpc_function()
df.addCallback(handleResult) df.addCallback(handleResult)
df.addErrback(handleError) df.addErrback(handleError)
df.addBoth(lambda _: lbrynet.dht.protocol.reactor.stop()) df.addBoth(lambda _: lbrynet.dht.protocol.reactor.stop())
@ -139,7 +139,7 @@ class KademliaProtocolTest(unittest.TestCase):
def testRPCRequestArgs(self): def testRPCRequestArgs(self):
""" Tests if an RPC requiring arguments is executed correctly """ """ Tests if an RPC requiring arguments is executed correctly """
remoteContact = lbrynet.dht.contact.Contact('node2', '127.0.0.1', 9182, self.protocol) remoteContact = lbrynet.dht.contact.Contact('2' * 48, '127.0.0.1', 9182, self.protocol)
self.node.addContact(remoteContact) self.node.addContact(remoteContact)
self.error = None self.error = None

View file

@ -27,6 +27,9 @@ class FakeDeferred(object):
def addErrback(self, *args, **kwargs): def addErrback(self, *args, **kwargs):
return return
def addCallbacks(self, *args, **kwargs):
return
class TreeRoutingTableTest(unittest.TestCase): class TreeRoutingTableTest(unittest.TestCase):
""" Test case for the RoutingTable class """ """ Test case for the RoutingTable class """

View file

@ -9,40 +9,41 @@ import unittest
from lbrynet.dht.msgtypes import RequestMessage, ResponseMessage, ErrorMessage from lbrynet.dht.msgtypes import RequestMessage, ResponseMessage, ErrorMessage
from lbrynet.dht.msgformat import MessageTranslator, DefaultFormat from lbrynet.dht.msgformat import MessageTranslator, DefaultFormat
class DefaultFormatTranslatorTest(unittest.TestCase): class DefaultFormatTranslatorTest(unittest.TestCase):
""" Test case for the default message translator """ """ Test case for the default message translator """
def setUp(self): def setUp(self):
self.cases = ((RequestMessage('node1', 'rpcMethod', self.cases = ((RequestMessage('1' * 48, 'rpcMethod',
{'arg1': 'a string', 'arg2': 123}, 'rpc1'), {'arg1': 'a string', 'arg2': 123}, '1' * 20),
{DefaultFormat.headerType: DefaultFormat.typeRequest, {DefaultFormat.headerType: DefaultFormat.typeRequest,
DefaultFormat.headerNodeID: 'node1', DefaultFormat.headerNodeID: '1' * 48,
DefaultFormat.headerMsgID: 'rpc1', DefaultFormat.headerMsgID: '1' * 20,
DefaultFormat.headerPayload: 'rpcMethod', DefaultFormat.headerPayload: 'rpcMethod',
DefaultFormat.headerArgs: {'arg1': 'a string', 'arg2': 123}}), DefaultFormat.headerArgs: {'arg1': 'a string', 'arg2': 123}}),
(ResponseMessage('rpc2', 'node2', 'response'), (ResponseMessage('2' * 20, '2' * 48, 'response'),
{DefaultFormat.headerType: DefaultFormat.typeResponse, {DefaultFormat.headerType: DefaultFormat.typeResponse,
DefaultFormat.headerNodeID: 'node2', DefaultFormat.headerNodeID: '2' * 48,
DefaultFormat.headerMsgID: 'rpc2', DefaultFormat.headerMsgID: '2' * 20,
DefaultFormat.headerPayload: 'response'}), DefaultFormat.headerPayload: 'response'}),
(ErrorMessage('rpc3', 'node3', (ErrorMessage('3' * 20, '3' * 48,
"<type 'exceptions.ValueError'>", 'this is a test exception'), "<type 'exceptions.ValueError'>", 'this is a test exception'),
{DefaultFormat.headerType: DefaultFormat.typeError, {DefaultFormat.headerType: DefaultFormat.typeError,
DefaultFormat.headerNodeID: 'node3', DefaultFormat.headerNodeID: '3' * 48,
DefaultFormat.headerMsgID: 'rpc3', DefaultFormat.headerMsgID: '3' * 20,
DefaultFormat.headerPayload: "<type 'exceptions.ValueError'>", DefaultFormat.headerPayload: "<type 'exceptions.ValueError'>",
DefaultFormat.headerArgs: 'this is a test exception'}), DefaultFormat.headerArgs: 'this is a test exception'}),
(ResponseMessage( (ResponseMessage(
'rpc4', 'node4', '4' * 20, '4' * 48,
[('H\x89\xb0\xf4\xc9\xe6\xc5`H>\xd5\xc2\xc5\xe8Od\xf1\xca\xfa\x82', [('H\x89\xb0\xf4\xc9\xe6\xc5`H>\xd5\xc2\xc5\xe8Od\xf1\xca\xfa\x82',
'127.0.0.1', 1919), '127.0.0.1', 1919),
('\xae\x9ey\x93\xdd\xeb\xf1^\xff\xc5\x0f\xf8\xac!\x0e\x03\x9fY@{', ('\xae\x9ey\x93\xdd\xeb\xf1^\xff\xc5\x0f\xf8\xac!\x0e\x03\x9fY@{',
'127.0.0.1', 1921)]), '127.0.0.1', 1921)]),
{DefaultFormat.headerType: DefaultFormat.typeResponse, {DefaultFormat.headerType: DefaultFormat.typeResponse,
DefaultFormat.headerNodeID: 'node4', DefaultFormat.headerNodeID: '4' * 48,
DefaultFormat.headerMsgID: 'rpc4', DefaultFormat.headerMsgID: '4' * 20,
DefaultFormat.headerPayload: DefaultFormat.headerPayload:
[('H\x89\xb0\xf4\xc9\xe6\xc5`H>\xd5\xc2\xc5\xe8Od\xf1\xca\xfa\x82', [('H\x89\xb0\xf4\xc9\xe6\xc5`H>\xd5\xc2\xc5\xe8Od\xf1\xca\xfa\x82',
'127.0.0.1', 1919), '127.0.0.1', 1919),
@ -81,13 +82,3 @@ class DefaultFormatTranslatorTest(unittest.TestCase):
'Message instance variable "%s" not translated correctly; ' 'Message instance variable "%s" not translated correctly; '
'expected "%s", got "%s"' % 'expected "%s", got "%s"' %
(key, msg.__dict__[key], translatedObj.__dict__[key])) (key, msg.__dict__[key], translatedObj.__dict__[key]))
def suite():
suite = unittest.TestSuite()
suite.addTest(unittest.makeSuite(DefaultFormatTranslatorTest))
return suite
if __name__ == '__main__':
# If this module is executed from the commandline, run all its tests
unittest.TextTestRunner().run(suite())