working functional test_contact_rpc + more string bans

This commit is contained in:
Victor Shyba 2018-07-20 16:45:58 -03:00 committed by Jack Robison
parent 1ee682f06f
commit e1314a9d1e
No known key found for this signature in database
GPG key ID: DF25C68FE0239BB2
8 changed files with 61 additions and 62 deletions

View file

@ -59,7 +59,7 @@ class _Contact(object):
def log_id(self, short=True): def log_id(self, short=True):
if not self.id: if not self.id:
return "not initialized" return "not initialized"
id_hex = self.id.encode('hex') id_hex = hexlify(self.id)
return id_hex if not short else id_hex[:8] return id_hex if not short else id_hex[:8]
@property @property
@ -162,7 +162,7 @@ class _Contact(object):
raise AttributeError("unknown command: %s" % name) raise AttributeError("unknown command: %s" % name)
def _sendRPC(*args, **kwargs): def _sendRPC(*args, **kwargs):
return self._networkProtocol.sendRPC(self, name, args) return self._networkProtocol.sendRPC(self, name.encode(), args)
return _sendRPC return _sendRPC

View file

@ -58,8 +58,6 @@ class Bencode(Encoding):
""" """
if isinstance(data, (int, long)): if isinstance(data, (int, long)):
return b'i%de' % data return b'i%de' % data
elif isinstance(data, str):
return b'%d:%s' % (len(data), data.encode())
elif isinstance(data, bytes): elif isinstance(data, bytes):
return b'%d:%s' % (len(data), data) return b'%d:%s' % (len(data), data)
elif isinstance(data, (list, tuple)): elif isinstance(data, (list, tuple)):

View file

@ -140,8 +140,7 @@ class KBucket(object):
if not. if not.
@rtype: bool @rtype: bool
""" """
if isinstance(key, str): assert type(key) in [long, bytes], "{} is {}".format(key, type(key)) # fixme: _maybe_ remove this after porting
key = long(hexlify(key.encode()), 16)
if isinstance(key, bytes): if isinstance(key, bytes):
key = long(hexlify(key), 16) key = long(hexlify(key), 16)
return self.rangeMin <= key < self.rangeMax return self.rangeMin <= key < self.rangeMax

View file

@ -48,6 +48,5 @@ class ErrorMessage(ResponseMessage):
def __init__(self, rpcID, nodeID, exceptionType, errorMessage): def __init__(self, rpcID, nodeID, exceptionType, errorMessage):
ResponseMessage.__init__(self, rpcID, nodeID, errorMessage) ResponseMessage.__init__(self, rpcID, nodeID, errorMessage)
if isinstance(exceptionType, type): if isinstance(exceptionType, type):
self.exceptionType = '%s.%s' % (exceptionType.__module__, exceptionType.__name__) exceptionType = ('%s.%s' % (exceptionType.__module__, exceptionType.__name__)).encode()
else: self.exceptionType = exceptionType
self.exceptionType = exceptionType

View file

@ -10,6 +10,8 @@ import binascii
import hashlib import hashlib
import struct import struct
import logging import logging
from functools import reduce
from twisted.internet import defer, error, task from twisted.internet import defer, error, task
from lbrynet.core.utils import generate_id, DeferredDict from lbrynet.core.utils import generate_id, DeferredDict
@ -493,7 +495,7 @@ class Node(MockKademliaHelper):
@rtype: str @rtype: str
""" """
return 'pong' return b'pong'
@rpcmethod @rpcmethod
def store(self, rpc_contact, blob_hash, token, port, originalPublisherID, age): def store(self, rpc_contact, blob_hash, token, port, originalPublisherID, age):
@ -530,13 +532,13 @@ class Node(MockKademliaHelper):
if 0 <= port <= 65536: if 0 <= port <= 65536:
compact_port = struct.pack('>H', port) compact_port = struct.pack('>H', port)
else: else:
raise TypeError('Invalid port') raise TypeError('Invalid port: {}'.format(port))
compact_address = compact_ip + compact_port + rpc_contact.id compact_address = compact_ip + compact_port + rpc_contact.id
now = int(self.clock.seconds()) now = int(self.clock.seconds())
originallyPublished = now - age originallyPublished = now - age
self._dataStore.addPeerToBlob(rpc_contact, blob_hash, compact_address, now, originallyPublished, self._dataStore.addPeerToBlob(rpc_contact, blob_hash, compact_address, now, originallyPublished,
originalPublisherID) originalPublisherID)
return 'OK' return b'OK'
@rpcmethod @rpcmethod
def findNode(self, rpc_contact, key): def findNode(self, rpc_contact, key):
@ -578,11 +580,11 @@ class Node(MockKademliaHelper):
raise ValueError("invalid blob hash length: %i" % len(key)) raise ValueError("invalid blob hash length: %i" % len(key))
response = { response = {
'token': self.make_token(rpc_contact.compact_ip()), b'token': self.make_token(rpc_contact.compact_ip()),
} }
if self._protocol._protocolVersion: if self._protocol._protocolVersion:
response['protocolVersion'] = self._protocol._protocolVersion response[b'protocolVersion'] = self._protocol._protocolVersion
# get peers we have stored for this blob # get peers we have stored for this blob
has_other_peers = self._dataStore.hasPeersForBlob(key) has_other_peers = self._dataStore.hasPeersForBlob(key)
@ -592,17 +594,15 @@ class Node(MockKademliaHelper):
# if we don't have k storing peers to return and we have this hash locally, include our contact information # if we don't have k storing peers to return and we have this hash locally, include our contact information
if len(peers) < constants.k and key in self._dataStore.completed_blobs: if len(peers) < constants.k and key in self._dataStore.completed_blobs:
compact_ip = str( compact_ip = reduce(lambda buff, x: buff + bytearray([int(x)]), self.externalIP.split('.'), bytearray())
reduce(lambda buff, x: buff + bytearray([int(x)]), self.externalIP.split('.'), bytearray()) compact_port = struct.pack('>H', self.peerPort)
)
compact_port = str(struct.pack('>H', self.peerPort))
compact_address = compact_ip + compact_port + self.node_id compact_address = compact_ip + compact_port + self.node_id
peers.append(compact_address) peers.append(compact_address)
if peers: if peers:
response[key] = peers response[key] = peers
else: else:
response['contacts'] = self.findNode(rpc_contact, key) response[b'contacts'] = self.findNode(rpc_contact, key)
return response return response
def _generateID(self): def _generateID(self):

View file

@ -1,6 +1,7 @@
import logging import logging
import socket import socket
import errno import errno
from binascii import hexlify
from collections import deque from collections import deque
from twisted.internet import protocol, defer from twisted.internet import protocol, defer
@ -108,12 +109,12 @@ class KademliaProtocol(protocol.DatagramProtocol):
self.started_listening_time = 0 self.started_listening_time = 0
def _migrate_incoming_rpc_args(self, contact, method, *args): def _migrate_incoming_rpc_args(self, contact, method, *args):
if method == 'store' and contact.protocolVersion == 0: if method == b'store' and contact.protocolVersion == 0:
if isinstance(args[1], dict): if isinstance(args[1], dict):
blob_hash = args[0] blob_hash = args[0]
token = args[1].pop('token', None) token = args[1].pop(b'token', None)
port = args[1].pop('port', -1) port = args[1].pop(b'port', -1)
originalPublisherID = args[1].pop('lbryid', None) originalPublisherID = args[1].pop(b'lbryid', None)
age = 0 age = 0
return (blob_hash, token, port, originalPublisherID, age), {} return (blob_hash, token, port, originalPublisherID, age), {}
return args, {} return args, {}
@ -124,16 +125,16 @@ class KademliaProtocol(protocol.DatagramProtocol):
protocol version keyword argument to calls to contacts who will accept it protocol version keyword argument to calls to contacts who will accept it
""" """
if contact.protocolVersion == 0: if contact.protocolVersion == 0:
if method == 'store': if method == b'store':
blob_hash, token, port, originalPublisherID, age = args blob_hash, token, port, originalPublisherID, age = args
args = (blob_hash, {'token': token, 'port': port, 'lbryid': originalPublisherID}, originalPublisherID, args = (blob_hash, {b'token': token, b'port': port, b'lbryid': originalPublisherID}, originalPublisherID,
False) False)
return args return args
return args return args
if args and isinstance(args[-1], dict): if args and isinstance(args[-1], dict):
args[-1]['protocolVersion'] = self._protocolVersion args[-1][b'protocolVersion'] = self._protocolVersion
return args return args
return args + ({'protocolVersion': self._protocolVersion},) return args + ({b'protocolVersion': self._protocolVersion},)
def sendRPC(self, contact, method, args): def sendRPC(self, contact, method, args):
""" """
@ -162,7 +163,7 @@ class KademliaProtocol(protocol.DatagramProtocol):
if args: if args:
log.debug("%s:%i SEND CALL %s(%s) TO %s:%i", self._node.externalIP, self._node.port, method, log.debug("%s:%i SEND CALL %s(%s) TO %s:%i", self._node.externalIP, self._node.port, method,
args[0].encode('hex'), contact.address, contact.port) hexlify(args[0]), contact.address, contact.port)
else: else:
log.debug("%s:%i SEND CALL %s TO %s:%i", self._node.externalIP, self._node.port, method, log.debug("%s:%i SEND CALL %s TO %s:%i", self._node.externalIP, self._node.port, method,
contact.address, contact.port) contact.address, contact.port)
@ -179,11 +180,11 @@ class KademliaProtocol(protocol.DatagramProtocol):
def _update_contact(result): # refresh the contact in the routing table def _update_contact(result): # refresh the contact in the routing table
contact.update_last_replied() contact.update_last_replied()
if method == 'findValue': if method == b'findValue':
if 'protocolVersion' not in result: if b'protocolVersion' not in result:
contact.update_protocol_version(0) contact.update_protocol_version(0)
else: else:
contact.update_protocol_version(result.pop('protocolVersion')) contact.update_protocol_version(result.pop(b'protocolVersion'))
d = self._node.addContact(contact) d = self._node.addContact(contact)
d.addCallback(lambda _: result) d.addCallback(lambda _: result)
return d return d
@ -214,8 +215,7 @@ class KademliaProtocol(protocol.DatagramProtocol):
@note: This is automatically called by Twisted when the protocol @note: This is automatically called by Twisted when the protocol
receives a UDP datagram receives a UDP datagram
""" """
if datagram[0] == b'\x00' and datagram[25] == b'\x00':
if datagram[0] == '\x00' and datagram[25] == '\x00':
totalPackets = (ord(datagram[1]) << 8) | ord(datagram[2]) totalPackets = (ord(datagram[1]) << 8) | ord(datagram[2])
msgID = datagram[5:25] msgID = datagram[5:25]
seqNumber = (ord(datagram[3]) << 8) | ord(datagram[4]) seqNumber = (ord(datagram[3]) << 8) | ord(datagram[4])
@ -307,7 +307,7 @@ class KademliaProtocol(protocol.DatagramProtocol):
# the node id of the node we sent a message to (these messages are treated as an error) # the node id of the node we sent a message to (these messages are treated as an error)
if remoteContact.id and remoteContact.id != message.nodeID: # sent_to_id will be None for bootstrap if remoteContact.id and remoteContact.id != message.nodeID: # sent_to_id will be None for bootstrap
log.debug("mismatch: (%s) %s:%i (%s vs %s)", method, remoteContact.address, remoteContact.port, log.debug("mismatch: (%s) %s:%i (%s vs %s)", method, remoteContact.address, remoteContact.port,
remoteContact.log_id(False), message.nodeID.encode('hex')) remoteContact.log_id(False), hexlify(message.nodeID))
df.errback(TimeoutError(remoteContact.id)) df.errback(TimeoutError(remoteContact.id))
return return
elif not remoteContact.id: elif not remoteContact.id:
@ -396,6 +396,7 @@ class KademliaProtocol(protocol.DatagramProtocol):
def _sendError(self, contact, rpcID, exceptionType, exceptionMessage): def _sendError(self, contact, rpcID, exceptionType, exceptionMessage):
""" Send an RPC error message to the specified contact """ Send an RPC error message to the specified contact
""" """
exceptionType, exceptionMessage = exceptionType.encode(), exceptionMessage.encode()
msg = msgtypes.ErrorMessage(rpcID, self._node.node_id, exceptionType, exceptionMessage) msg = msgtypes.ErrorMessage(rpcID, self._node.node_id, exceptionType, exceptionMessage)
msgPrimitive = self._translator.toPrimitive(msg) msgPrimitive = self._translator.toPrimitive(msg)
encodedMsg = self._encoder.encode(msgPrimitive) encodedMsg = self._encoder.encode(msgPrimitive)
@ -416,7 +417,7 @@ class KademliaProtocol(protocol.DatagramProtocol):
df.addErrback(handleError) df.addErrback(handleError)
# Execute the RPC # Execute the RPC
func = getattr(self._node, method, None) func = getattr(self._node, method.decode(), None)
if callable(func) and hasattr(func, "rpcmethod"): if callable(func) and hasattr(func, "rpcmethod"):
# Call the exposed Node method and return the result to the deferred callback chain # Call the exposed Node method and return the result to the deferred callback chain
# if args: # if args:
@ -425,14 +426,14 @@ class KademliaProtocol(protocol.DatagramProtocol):
# else: # else:
log.debug("%s:%i RECV CALL %s %s:%i", self._node.externalIP, self._node.port, method, log.debug("%s:%i RECV CALL %s %s:%i", self._node.externalIP, self._node.port, method,
senderContact.address, senderContact.port) senderContact.address, senderContact.port)
if args and isinstance(args[-1], dict) and 'protocolVersion' in args[-1]: # args don't need reformatting if args and isinstance(args[-1], dict) and b'protocolVersion' in args[-1]: # args don't need reformatting
senderContact.update_protocol_version(int(args[-1].pop('protocolVersion'))) senderContact.update_protocol_version(int(args[-1].pop(b'protocolVersion')))
a, kw = tuple(args[:-1]), args[-1] a, kw = tuple(args[:-1]), args[-1]
else: else:
senderContact.update_protocol_version(0) senderContact.update_protocol_version(0)
a, kw = self._migrate_incoming_rpc_args(senderContact, method, *args) a, kw = self._migrate_incoming_rpc_args(senderContact, method, *args)
try: try:
if method != 'ping': if method != b'ping':
result = func(senderContact, *a) result = func(senderContact, *a)
else: else:
result = func() result = func()

View file

@ -1,3 +1,5 @@
from binascii import unhexlify
import time import time
from twisted.trial import unittest from twisted.trial import unittest
import logging import logging
@ -19,12 +21,12 @@ class KademliaProtocolTest(unittest.TestCase):
def setUp(self): def setUp(self):
self._reactor = Clock() self._reactor = Clock()
self.node = Node(node_id='1' * 48, udpPort=self.udpPort, externalIP="127.0.0.1", listenUDP=listenUDP, self.node = Node(node_id=b'1' * 48, udpPort=self.udpPort, externalIP="127.0.0.1", listenUDP=listenUDP,
resolve=resolve, clock=self._reactor, callLater=self._reactor.callLater) resolve=resolve, clock=self._reactor, callLater=self._reactor.callLater)
self.remote_node = Node(node_id='2' * 48, udpPort=self.udpPort, externalIP="127.0.0.2", listenUDP=listenUDP, self.remote_node = Node(node_id=b'2' * 48, udpPort=self.udpPort, externalIP="127.0.0.2", listenUDP=listenUDP,
resolve=resolve, clock=self._reactor, callLater=self._reactor.callLater) resolve=resolve, clock=self._reactor, callLater=self._reactor.callLater)
self.remote_contact = self.node.contact_manager.make_contact('2' * 48, '127.0.0.2', 9182, self.node._protocol) self.remote_contact = self.node.contact_manager.make_contact(b'2' * 48, '127.0.0.2', 9182, self.node._protocol)
self.us_from_them = self.remote_node.contact_manager.make_contact('1' * 48, '127.0.0.1', 9182, self.us_from_them = self.remote_node.contact_manager.make_contact(b'1' * 48, '127.0.0.1', 9182,
self.remote_node._protocol) self.remote_node._protocol)
self.node.start_listening() self.node.start_listening()
self.remote_node.start_listening() self.remote_node.start_listening()
@ -105,7 +107,7 @@ class KademliaProtocolTest(unittest.TestCase):
self.error = 'An RPC error occurred: %s' % f.getErrorMessage() self.error = 'An RPC error occurred: %s' % f.getErrorMessage()
def handleResult(result): def handleResult(result):
expectedResult = 'pong' expectedResult = b'pong'
if result != expectedResult: if result != expectedResult:
self.error = 'Result from RPC is incorrect; expected "%s", got "%s"' \ self.error = 'Result from RPC is incorrect; expected "%s", got "%s"' \
% (expectedResult, result) % (expectedResult, result)
@ -142,7 +144,7 @@ class KademliaProtocolTest(unittest.TestCase):
self.error = 'An RPC error occurred: %s' % f.getErrorMessage() self.error = 'An RPC error occurred: %s' % f.getErrorMessage()
def handleResult(result): def handleResult(result):
expectedResult = 'pong' expectedResult = b'pong'
if result != expectedResult: if result != expectedResult:
self.error = 'Result from RPC is incorrect; expected "%s", got "%s"' % \ self.error = 'Result from RPC is incorrect; expected "%s", got "%s"' % \
(expectedResult, result) (expectedResult, result)
@ -163,12 +165,12 @@ class KademliaProtocolTest(unittest.TestCase):
@defer.inlineCallbacks @defer.inlineCallbacks
def testDetectProtocolVersion(self): def testDetectProtocolVersion(self):
original_findvalue = self.remote_node.findValue original_findvalue = self.remote_node.findValue
fake_blob = str("AB" * 48).decode('hex') fake_blob = unhexlify("AB" * 48)
@rpcmethod @rpcmethod
def findValue(contact, key): def findValue(contact, key):
result = original_findvalue(contact, key) result = original_findvalue(contact, key)
result.pop('protocolVersion') result.pop(b'protocolVersion')
return result return result
self.remote_node.findValue = findValue self.remote_node.findValue = findValue
@ -205,35 +207,35 @@ class KademliaProtocolTest(unittest.TestCase):
@rpcmethod @rpcmethod
def findValue(contact, key): def findValue(contact, key):
result = original_findvalue(contact, key) result = original_findvalue(contact, key)
if 'protocolVersion' in result: if b'protocolVersion' in result:
result.pop('protocolVersion') result.pop(b'protocolVersion')
return result return result
@rpcmethod @rpcmethod
def store(contact, key, value, originalPublisherID=None, self_store=False, **kwargs): def store(contact, key, value, originalPublisherID=None, self_store=False, **kwargs):
self.assertTrue(len(key) == 48) self.assertTrue(len(key) == 48)
self.assertSetEqual(set(value.keys()), {'token', 'lbryid', 'port'}) self.assertSetEqual(set(value.keys()), {b'token', b'lbryid', b'port'})
self.assertFalse(self_store) self.assertFalse(self_store)
self.assertDictEqual(kwargs, {}) self.assertDictEqual(kwargs, {})
return original_store( # pylint: disable=too-many-function-args return original_store( # pylint: disable=too-many-function-args
contact, key, value['token'], value['port'], originalPublisherID, 0 contact, key, value[b'token'], value[b'port'], originalPublisherID, 0
) )
self.remote_node.findValue = findValue self.remote_node.findValue = findValue
self.remote_node.store = store self.remote_node.store = store
fake_blob = str("AB" * 48).decode('hex') fake_blob = unhexlify("AB" * 48)
d = self.remote_contact.findValue(fake_blob) d = self.remote_contact.findValue(fake_blob)
self._reactor.advance(3) self._reactor.advance(3)
find_value_response = yield d find_value_response = yield d
self.assertEqual(self.remote_contact.protocolVersion, 0) self.assertEqual(self.remote_contact.protocolVersion, 0)
self.assertTrue('protocolVersion' not in find_value_response) self.assertTrue(b'protocolVersion' not in find_value_response)
token = find_value_response['token'] token = find_value_response[b'token']
d = self.remote_contact.store(fake_blob, token, 3333, self.node.node_id, 0) d = self.remote_contact.store(fake_blob, token, 3333, self.node.node_id, 0)
self._reactor.advance(3) self._reactor.advance(3)
response = yield d response = yield d
self.assertEqual(response, "OK") self.assertEqual(response, b'OK')
self.assertEqual(self.remote_contact.protocolVersion, 0) self.assertEqual(self.remote_contact.protocolVersion, 0)
self.assertTrue(self.remote_node._dataStore.hasPeersForBlob(fake_blob)) self.assertTrue(self.remote_node._dataStore.hasPeersForBlob(fake_blob))
self.assertEqual(len(self.remote_node._dataStore.getStoringContacts()), 1) self.assertEqual(len(self.remote_node._dataStore.getStoringContacts()), 1)
@ -245,24 +247,24 @@ class KademliaProtocolTest(unittest.TestCase):
self.remote_node._protocol._migrate_outgoing_rpc_args = _dont_migrate self.remote_node._protocol._migrate_outgoing_rpc_args = _dont_migrate
us_from_them = self.remote_node.contact_manager.make_contact('1' * 48, '127.0.0.1', self.udpPort, us_from_them = self.remote_node.contact_manager.make_contact(b'1' * 48, '127.0.0.1', self.udpPort,
self.remote_node._protocol) self.remote_node._protocol)
fake_blob = str("AB" * 48).decode('hex') fake_blob = unhexlify("AB" * 48)
d = us_from_them.findValue(fake_blob) d = us_from_them.findValue(fake_blob)
self._reactor.advance(3) self._reactor.advance(3)
find_value_response = yield d find_value_response = yield d
self.assertEqual(self.remote_contact.protocolVersion, 0) self.assertEqual(self.remote_contact.protocolVersion, 0)
self.assertTrue('protocolVersion' not in find_value_response) self.assertTrue(b'protocolVersion' not in find_value_response)
token = find_value_response['token'] token = find_value_response[b'token']
us_from_them.update_protocol_version(0) us_from_them.update_protocol_version(0)
d = self.remote_node._protocol.sendRPC( d = self.remote_node._protocol.sendRPC(
us_from_them, "store", (fake_blob, {'lbryid': self.remote_node.node_id, 'token': token, 'port': 3333}) us_from_them, b"store", (fake_blob, {b'lbryid': self.remote_node.node_id, b'token': token, b'port': 3333})
) )
self._reactor.advance(3) self._reactor.advance(3)
response = yield d response = yield d
self.assertEqual(response, "OK") self.assertEqual(response, b'OK')
self.assertEqual(self.remote_contact.protocolVersion, 0) self.assertEqual(self.remote_contact.protocolVersion, 0)
self.assertTrue(self.node._dataStore.hasPeersForBlob(fake_blob)) self.assertTrue(self.node._dataStore.hasPeersForBlob(fake_blob))
self.assertEqual(len(self.node._dataStore.getStoringContacts()), 1) self.assertEqual(len(self.node._dataStore.getStoringContacts()), 1)

View file

@ -32,7 +32,7 @@ class TreeRoutingTableTest(unittest.TestCase):
""" Test to see if distance method returns correct result""" """ Test to see if distance method returns correct result"""
# testList holds a couple 3-tuple (variable1, variable2, result) # testList holds a couple 3-tuple (variable1, variable2, result)
basicTestList = [(bytes([170] * 48), bytes([85] * 48), long(hexlify(bytes([255] * 48)), 16))] basicTestList = [(bytes(b'\xaa' * 48), bytes(b'\x55' * 48), long(hexlify(bytes(b'\xff' * 48)), 16))]
for test in basicTestList: for test in basicTestList:
result = Distance(test[0])(test[1]) result = Distance(test[0])(test[1])
@ -139,7 +139,7 @@ class TreeRoutingTableTest(unittest.TestCase):
Test that a bucket is not split if it is full, but the new contact is not closer than the kth closest contact Test that a bucket is not split if it is full, but the new contact is not closer than the kth closest contact
""" """
self.routingTable._parentNodeID = bytes(48 * [255]) self.routingTable._parentNodeID = bytes(48 * b'\xff')
node_ids = [ node_ids = [
b"100000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000", b"100000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000",