diff --git a/cmd/dht.go b/cmd/dht.go
index 1a6a458..1a8f5d5 100644
--- a/cmd/dht.go
+++ b/cmd/dht.go
@@ -80,7 +80,7 @@ func dhtCmd(cmd *cobra.Command, args []string) {
 	//d.WaitUntilJoined()
 
 	nodes := 10
-	dhts := dht.TestingCreateDHT(nodes)
+	_, dhts := dht.TestingCreateDHT(nodes)
 	defer func() {
 		for _, d := range dhts {
 			go d.Shutdown()
diff --git a/dht/bitmap.go b/dht/bitmap.go
index cf91aa4..8ede91d 100644
--- a/dht/bitmap.go
+++ b/dht/bitmap.go
@@ -60,6 +60,26 @@ func (b Bitmap) PrefixLen() int {
 	return numBuckets
 }
 
+// ZeroPrefix returns a copy of b with the first n bits set to 0
+// https://stackoverflow.com/a/23192263/182709
+func (b Bitmap) ZeroPrefix(n int) Bitmap {
+	var ret Bitmap
+	copy(ret[:], b[:])
+
+Outer:
+	for i := range ret {
+		for j := 0; j < 8; j++ {
+			if i*8+j < n {
+				// clear bit j of byte i, counting bits from the most significant
+				ret[i] &= ^(1 << uint(7-j))
+			} else {
+				break Outer
+			}
+		}
+	}
+
+	return ret
+}
+
 func (b Bitmap) MarshalBencode() ([]byte, error) {
 	str := string(b[:])
 	return bencode.EncodeBytes(str)
diff --git a/dht/bitmap_test.go b/dht/bitmap_test.go
index 3f3dbc0..4b32f6e 100644
--- a/dht/bitmap_test.go
+++ b/dht/bitmap_test.go
@@ -99,23 +99,55 @@ func TestBitmapMarshalEmbedded2(t *testing.T) {
 
 func TestBitmap_PrefixLen(t *testing.T) {
 	tt := []struct {
-		str string
+		hex string
 		len int
 	}{
-		{len: 0, str: "F00000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000"},
-		{len: 0, str: "800000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000"},
-		{len: 1, str: "700000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000"},
-		{len: 1, str: "400000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000"},
-		{len: 384, str: "000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000"},
-		{len: 383, str: "000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000001"},
-		{len: 382, str: "000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000002"},
-		{len: 382, str: "000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000003"},
+		{len: 0, hex: "F00000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000"},
+		{len: 0, hex: "800000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000"},
+		{len: 1, hex: "700000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000"},
+		{len: 1, hex: "400000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000"},
+		{len: 384, hex: "000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000"},
+		{len: 383, hex: "000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000001"},
+		{len: 382, hex: "000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000002"},
+		{len: 382, hex: "000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000003"},
 	}
 
 	for _, test := range tt {
-		len := BitmapFromHexP(test.str).PrefixLen()
+		len := BitmapFromHexP(test.hex).PrefixLen()
 		if len != test.len {
-			t.Errorf("got prefix len %d; expected %d for %s", len, test.len, test.str)
+			t.Errorf("got prefix len %d; expected %d for %s", len, test.len, test.hex)
+		}
+	}
+}
+
+func TestBitmap_ZeroPrefix(t *testing.T) {
+	original := BitmapFromHexP("ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff")
+
+	tt := []struct {
+		zeros    int
+		expected string
+	}{
+		{zeros: -123, expected: "ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff"},
+		{zeros: 0, expected: "ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff"},
+		{zeros: 1, expected: "7fffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff"},
+		{zeros: 69, expected: "000000000000000007ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff"},
+		{zeros: 383, expected: "000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000001"},
+		{zeros: 384, expected: "000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000"},
+		{zeros: 400, expected: "000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000"},
+	}
+
+	for _, test := range tt {
+		expected := BitmapFromHexP(test.expected)
+		actual := original.ZeroPrefix(test.zeros)
+		if !actual.Equals(expected) {
+			t.Errorf("%d zeros: got %s; expected %s", test.zeros, actual.Hex(), expected.Hex())
+		}
+	}
+
+	for i := 0; i < nodeIDLength*8; i++ {
+		b := original.ZeroPrefix(i)
+		if b.PrefixLen() != i {
+			t.Errorf("got prefix len %d; expected %d for %s", b.PrefixLen(), i, b.Hex())
 		}
 	}
 }
diff --git a/dht/bootstrap.go b/dht/bootstrap.go
index db95f42..a6423f9 100644
--- a/dht/bootstrap.go
+++ b/dht/bootstrap.go
@@ -1,7 +1,212 @@
 package dht
 
-// DHT represents a DHT node.
+import (
+	"context"
+	"math/rand"
+	"net"
+	"sync"
+	"time"
+
+	log "github.com/sirupsen/logrus"
+)
+
+const (
+	bootstrapDefaultRefreshDuration = 15 * time.Minute
+)
+
+type nullStore struct{}
+
+func (n nullStore) Upsert(id Bitmap, c Contact) {}
+func (n nullStore) Get(id Bitmap) []Contact     { return nil }
+func (n nullStore) CountStoredHashes() int      { return 0 }
+
+type nullRoutingTable struct{}
+
+// TODO: the bootstrap logic *could* be implemented just in the routing table, without a custom request handler
+// TODO: the only tricky part is triggering the ping when Fresh is called, as the rt doesn't have access to the node
+
+func (n nullRoutingTable) Update(c Contact)                          {}             // this
+func (n nullRoutingTable) Fresh(c Contact)                           {}             // this
+func (n nullRoutingTable) Fail(c Contact)                            {}             // this
+func (n nullRoutingTable) GetClosest(id Bitmap, limit int) []Contact { return nil } // this
+func (n nullRoutingTable) Count() int                                { return 0 }
+func (n nullRoutingTable) GetIDsForRefresh(d time.Duration) []Bitmap { return nil }
+func (n nullRoutingTable) BucketInfo() string                        { return "" }
+
 type BootstrapNode struct {
-	// node
-	node *Node
+	Node
+
+	initialPingInterval time.Duration
+	checkInterval       time.Duration
+
+	nlock    *sync.RWMutex
+	nodes    []peer
+	nodeKeys map[Bitmap]int
+}
+
+// NewBootstrapNode returns a BootstrapNode pointer.
+func NewBootstrapNode(id Bitmap, initialPingInterval, rePingInterval time.Duration) *BootstrapNode {
+	b := &BootstrapNode{
+		Node: *NewNode(id),
+
+		initialPingInterval: initialPingInterval,
+		checkInterval:       rePingInterval,
+
+		nlock:    &sync.RWMutex{},
+		nodes:    make([]peer, 0),
+		nodeKeys: make(map[Bitmap]int),
+	}
+
+	b.rt = &nullRoutingTable{}
+	b.store = &nullStore{}
+	b.requestHandler = b.handleRequest
+
+	return b
+}
+
+// Add manually adds a contact
+func (b *BootstrapNode) Add(c Contact) {
+	b.upsert(c)
+}
+
+// Connect starts the node on the given connection and launches any necessary background goroutines
+func (b *BootstrapNode) Connect(conn UDPConn) error {
+	err := b.Node.Connect(conn)
+	if err != nil {
+		return err
+	}
+
+	log.Debugf("[%s] bootstrap: node connected", b.id.HexShort())
+
+	go func() {
+		t := time.NewTicker(b.checkInterval / 5)
+		for {
+			select {
+			case <-t.C:
+				b.check()
+			case <-b.stop.Chan():
+				return
+			}
+		}
+	}()
+
+	return nil
+}
+
+// upsert adds the contact to the list, or updates its lastActivity time
+func (b *BootstrapNode) upsert(c Contact) {
+	b.nlock.Lock()
+	defer b.nlock.Unlock()
+
+	if i, exists := b.nodeKeys[c.id]; exists {
+		log.Debugf("[%s] bootstrap: touching contact %s", b.id.HexShort(), b.nodes[i].contact.id.HexShort())
+		b.nodes[i].Touch()
+		return
+	}
+
+	log.Debugf("[%s] bootstrap: adding new contact %s", b.id.HexShort(), c.id.HexShort())
+	b.nodeKeys[c.id] = len(b.nodes)
+	b.nodes = append(b.nodes, peer{c, time.Now(), 0})
+}
+
+// remove removes the contact from the list
+func (b *BootstrapNode) remove(c Contact) {
+	b.nlock.Lock()
+	defer b.nlock.Unlock()
+
+	i, exists := b.nodeKeys[c.id]
+	if !exists {
+		return
+	}
+
+	log.Debugf("[%s] bootstrap: removing contact %s", b.id.HexShort(), c.id.HexShort())
+	b.nodes = append(b.nodes[:i], b.nodes[i+1:]...)
+	delete(b.nodeKeys, c.id)
+	// re-index the contacts that shifted down after the removal
+	for j := i; j < len(b.nodes); j++ {
+		b.nodeKeys[b.nodes[j].contact.id] = j
+	}
+}
+
+// get returns up to `limit` random contacts from the list
+func (b *BootstrapNode) get(limit int) []Contact {
+	b.nlock.RLock()
+	defer b.nlock.RUnlock()
+
+	if len(b.nodes) < limit {
+		limit = len(b.nodes)
+	}
+
+	ret := make([]Contact, limit)
+	for i, k := range randKeys(len(b.nodes))[:limit] {
+		ret[i] = b.nodes[k].contact
+	}
+
+	return ret
+}
+
+// ping pings a node. if the node responds, it is added to the list. otherwise, it is removed
+func (b *BootstrapNode) ping(c Contact) {
+	b.stopWG.Add(1)
+	defer b.stopWG.Done()
+
+	ctx, cancel := context.WithCancel(context.Background())
+	resCh := b.SendAsync(ctx, c, Request{Method: pingMethod})
+
+	var res *Response
+
+	select {
+	case res = <-resCh:
+	case <-b.stop.Chan():
+		cancel()
+		return
+	}
+
+	if res != nil && res.Data == pingSuccessResponse {
+		b.upsert(c)
+	} else {
+		b.remove(c)
+	}
+}
+
+func (b *BootstrapNode) check() {
+	b.nlock.RLock()
+	defer b.nlock.RUnlock()
+
+	for i := range b.nodes {
+		if !b.nodes[i].ActiveInLast(b.checkInterval) {
+			go b.ping(b.nodes[i].contact)
+		}
+	}
+}
+
+// handleRequest handles the requests received from udp.
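+// it answers ping and findNode requests directly, and queues every requester to be
+// pinged after initialPingInterval so responsive nodes end up in the contact list.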
+func (b *BootstrapNode) handleRequest(addr *net.UDPAddr, request Request) {
+	switch request.Method {
+	case pingMethod:
+		b.sendMessage(addr, Response{ID: request.ID, NodeID: b.id, Data: pingSuccessResponse})
+	case findNodeMethod:
+		if request.Arg == nil {
+			log.Errorln("request is missing arg")
+			return
+		}
+		b.sendMessage(addr, Response{
+			ID:       request.ID,
+			NodeID:   b.id,
+			Contacts: b.get(bucketSize),
+		})
+	}
+
+	go func() {
+		log.Debugf("[%s] bootstrap: queuing %s to ping", b.id.HexShort(), request.NodeID.HexShort())
+		<-time.After(b.initialPingInterval)
+		b.ping(Contact{id: request.NodeID, ip: addr.IP, port: addr.Port})
+	}()
+}
+
+func randKeys(max int) []int {
+	keys := make([]int, max)
+	for k := range keys {
+		keys[k] = k
+	}
+	rand.Shuffle(max, func(i, j int) {
+		keys[i], keys[j] = keys[j], keys[i]
+	})
+	return keys
 }
diff --git a/dht/bootstrap_test.go b/dht/bootstrap_test.go
new file mode 100644
index 0000000..8b45dee
--- /dev/null
+++ b/dht/bootstrap_test.go
@@ -0,0 +1,20 @@
+package dht
+
+import (
+	"net"
+	"testing"
+)
+
+func TestBootstrapPing(t *testing.T) {
+	b := NewBootstrapNode(RandomBitmapP(), 10, bootstrapDefaultRefreshDuration)
+
+	listener, err := net.ListenPacket(network, "127.0.0.1:54320")
+	if err != nil {
+		panic(err)
+	}
+
+	b.Connect(listener.(*net.UDPConn))
+	defer b.Shutdown()
+
+	b.Shutdown()
+}
diff --git a/dht/dht.go b/dht/dht.go
index 9561fe3..f9217f5 100644
--- a/dht/dht.go
+++ b/dht/dht.go
@@ -34,6 +34,8 @@ const (
 	udpTimeout          = 5 * time.Second
 	udpMaxMessageLength = 1024 // bytes. I think our longest message is ~676 bytes, so I rounded up
 
+	maxPeerFails = 3 // after this many failures, a peer is considered bad and will be removed from the routing table
+
 	tExpire    = 24 * time.Hour // the time after which a key/value pair expires; this is a time-to-live (TTL) from the original publication date
 	tRefresh   = 1 * time.Hour  // the time after which an otherwise unaccessed bucket must be refreshed
 	tReplicate = 1 * time.Hour  // the interval between Kademlia replication events, when a node is required to publish its entire database
@@ -97,15 +99,10 @@ func New(config *Config) (*DHT, error) {
 		return nil, err
 	}
 
-	node, err := NewNode(contact.id)
-	if err != nil {
-		return nil, err
-	}
-
 	d := &DHT{
 		conf:    config,
 		contact: contact,
-		node:    node,
+		node:    NewNode(contact.id),
 		stop:    stopOnce.New(),
 		stopWG:  &sync.WaitGroup{},
 		joined:  make(chan struct{}),
@@ -136,7 +133,8 @@ func (dht *DHT) join() {
 	}
 
 	// now call iterativeFind on yourself
-	_, err := dht.Get(dht.node.id)
+	nf := newContactFinder(dht.node, dht.node.id, false)
+	_, err := nf.Find()
 	if err != nil {
 		log.Errorf("[%s] join: %s", dht.node.id.HexShort(), err.Error())
 	}
@@ -227,15 +225,18 @@ func (dht *DHT) storeOnNode(hash Bitmap, node Contact) {
 	dht.stopWG.Add(1)
 	defer dht.stopWG.Done()
 
-	resCh := dht.node.SendAsync(context.Background(), node, Request{
+	ctx, cancel := context.WithCancel(context.Background())
+	resCh := dht.node.SendAsync(ctx, node, Request{
 		Method: findValueMethod,
 		Arg:    &hash,
 	})
+
 	var res *Response
 	select {
 	case res = <-resCh:
 	case <-dht.stop.Chan():
+		cancel()
 		return
 	}
 
@@ -243,7 +244,8 @@
 		return // request timed out
 	}
 
-	dht.node.SendAsync(context.Background(), node, Request{
+	ctx, cancel = context.WithCancel(context.Background())
+	resCh = dht.node.SendAsync(ctx, node, Request{
 		Method: storeMethod,
 		StoreArgs: &storeArgs{
 			BlobHash: hash,
@@ -254,6 +256,14 @@
 			},
 		},
 	})
+
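+	// drain the store response in the background so the pending transaction is
+	// cleaned up; if the dht shuts down first, cancel the request instead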
+	go func() {
+		select {
+		case <-resCh:
+		case <-dht.stop.Chan():
+			cancel()
+		}
+	}()
 }
 
 func (dht *DHT) PrintState() {
diff --git a/dht/dht_test.go b/dht/dht_test.go
index 12c4425..3b0a61d 100644
--- a/dht/dht_test.go
+++ b/dht/dht_test.go
@@ -13,11 +13,12 @@ import (
 
 // TODO: make a dht with X nodes, have them all join, then ensure that every node appears at least once in another node's routing table
 func TestNodeFinder_FindNodes(t *testing.T) {
-	dhts := TestingCreateDHT(3)
+	bs, dhts := TestingCreateDHT(3)
 	defer func() {
 		for i := range dhts {
 			dhts[i].Shutdown()
 		}
+		bs.Shutdown()
 	}()
 
 	nf := newContactFinder(dhts[2].node, RandomBitmapP(), false)
@@ -31,38 +32,61 @@
 		t.Fatal("something was found, but it should not have been")
 	}
 
-	if len(foundNodes) != 1 {
-		t.Errorf("expected 1 node, found %d", len(foundNodes))
+	if len(foundNodes) != 3 {
+		t.Errorf("expected 3 nodes, found %d", len(foundNodes))
 	}
 
+	foundBootstrap := false
 	foundOne := false
-	//foundTwo := false
+	foundTwo := false
 
 	for _, n := range foundNodes {
+		if n.id.Equals(bs.id) {
+			foundBootstrap = true
+		}
 		if n.id.Equals(dhts[0].node.id) {
 			foundOne = true
 		}
-		//if n.id.Equals(dhts[1].node.c.id) {
-		//	foundTwo = true
-		//}
+		if n.id.Equals(dhts[1].node.id) {
+			foundTwo = true
+		}
 	}
 
+	if !foundBootstrap {
+		t.Errorf("did not find bootstrap node %s", bs.id.Hex())
+	}
 	if !foundOne {
 		t.Errorf("did not find first node %s", dhts[0].node.id.Hex())
 	}
-	//if !foundTwo {
-	//	t.Errorf("did not find second node %s", dhts[1].node.c.id.Hex())
-	//}
+	if !foundTwo {
+		t.Errorf("did not find second node %s", dhts[1].node.id.Hex())
+	}
 }
 
-func TestNodeFinder_FindValue(t *testing.T) {
-	dhts := TestingCreateDHT(3)
+func TestNodeFinder_FindNodes_NoBootstrap(t *testing.T) {
+	dhts := TestingCreateDHTNoBootstrap(3, nil)
 	defer func() {
 		for i := range dhts {
 			dhts[i].Shutdown()
 		}
 	}()
 
+	nf := newContactFinder(dhts[2].node, RandomBitmapP(), false)
+	_, err := nf.Find()
+	if err == nil {
+		t.Fatal("contact finder should have errored saying that there are no contacts in the routing table")
+	}
+}
+
+func TestNodeFinder_FindValue(t *testing.T) {
+	bs, dhts := TestingCreateDHT(3)
+	defer func() {
+		for i := range dhts {
+			dhts[i].Shutdown()
+		}
+		bs.Shutdown()
+	}()
+
 	blobHashToFind := RandomBitmapP()
 	nodeToFind := Contact{id: RandomBitmapP(), ip: net.IPv4(1, 2, 3, 4), port: 5678}
 	dhts[0].node.store.Upsert(blobHashToFind, nodeToFind)
@@ -91,11 +115,12 @@ func TestDHT_LargeDHT(t *testing.T) {
 	rand.Seed(time.Now().UnixNano())
 	log.Println("if this takes longer than 20 seconds, its stuck. idk why it gets stuck sometimes, but its a bug.")
 
 	nodes := 100
-	dhts := TestingCreateDHT(nodes)
+	bs, dhts := TestingCreateDHT(nodes)
 	defer func() {
 		for _, d := range dhts {
 			go d.Shutdown()
 		}
+		bs.Shutdown()
 		time.Sleep(1 * time.Second)
 	}()
@@ -115,5 +140,5 @@
 	}
 	wg.Wait()
 
-	dhts[1].PrintState()
+	dhts[len(dhts)-1].PrintState()
 }
diff --git a/dht/node.go b/dht/node.go
index 0bdef9e..92b3e10 100644
--- a/dht/node.go
+++ b/dht/node.go
@@ -24,6 +24,7 @@ type packet struct {
 }
 
 // UDPConn allows using a mocked connection to test sending/receiving data
+// TODO: stop mocking this and use the real thing
 type UDPConn interface {
 	ReadFromUDP([]byte) (int, *net.UDPAddr, error)
 	WriteToUDP([]byte, *net.UDPAddr) (int, error)
@@ -32,6 +33,8 @@
 	Close() error
 }
 
+type RequestHandlerFunc func(addr *net.UDPAddr, request Request)
+
 type Node struct {
 	// the node's id
 	id Bitmap
@@ -45,20 +48,24 @@ type Node struct {
 	transactions map[messageID]*transaction
 
 	// routing table
-	rt *routingTable
+	rt RoutingTable
 	// data store
-	store *peerStore
+	store Store
 
+	// an override for the default request handler
+	requestHandler RequestHandlerFunc
+
+	// stop the node neatly and clean up after itself
 	stop   *stopOnce.Stopper
 	stopWG *sync.WaitGroup
 }
 
 // New returns a Node pointer.
-func NewNode(id Bitmap) (*Node, error) {
-	n := &Node{
+func NewNode(id Bitmap) *Node {
+	return &Node{
 		id:    id,
 		rt:    newRoutingTable(id),
-		store: newPeerStore(),
+		store: newStore(),
 
 		txLock:       &sync.RWMutex{},
 		transactions: make(map[messageID]*transaction),
@@ -67,11 +74,9 @@
 		stopWG: &sync.WaitGroup{},
 		tokens: &tokenManager{},
 	}
-
-	n.tokens.Start(tokenSecretRotationInterval)
-	return n, nil
 }
 
+// Connect starts the node on the given connection and launches any necessary background goroutines
 func (n *Node) Connect(conn UDPConn) error {
 	n.conn = conn
@@ -89,6 +94,8 @@ func (n *Node) Connect(conn UDPConn) error {
 	//	}()
 	//}
 
+	n.tokens.Start(tokenSecretRotationInterval)
+
 	packets := make(chan packet)
 
 	go func() {
@@ -139,6 +146,8 @@
 		}
 	}()
 
+	n.startRoutingTableGrooming()
+
 	return nil
 }
@@ -161,10 +170,9 @@ func (n *Node) handlePacket(pkt packet) {
 		return
 	}
 
-	// TODO: test this stuff more thoroughly
-
 	// the following is a bit of a hack, but it lets us avoid decoding every message twice
 	// it depends on the data being a dict with 0 as the first key (so it starts with "d1:0i") and the message type as the first value
+	// TODO: test this more thoroughly
 
 	switch pkt.data[5] {
 	case '0' + requestType:
@@ -210,9 +218,15 @@ func (n *Node) handleRequest(addr *net.UDPAddr, request Request) {
 		return
 	}
 
+	// if a handler is overridden, call it instead
+	if n.requestHandler != nil {
+		n.requestHandler(addr, request)
+		return
+	}
+
 	switch request.Method {
 	default:
-		// n.send(addr, makeError(t, protocolError, "invalid q"))
+		//n.sendMessage(addr, Error{ID: request.ID, NodeID: n.id, ExceptionType: "invalid-request-method"})
 		log.Errorln("invalid request method")
 		return
 	case pingMethod:
@@ -263,7 +277,7 @@
 	// the routing table must only contain "good" nodes, which are nodes that reply to our requests
 	// if a node is already good (aka in the table), its fine to refresh it
 	// http://www.bittorrent.org/beps/bep_0005.html#routing-table
-	n.rt.UpdateIfExists(Contact{id: request.NodeID, ip: addr.IP, port: addr.Port})
+	n.rt.Fresh(Contact{id: request.NodeID, ip: addr.IP, port: addr.Port})
 }
 
 // handleResponse handles responses received from udp.
@@ -279,7 +293,7 @@ func (n *Node) handleResponse(addr *net.UDPAddr, response Response) {
 
 // handleError handles errors received from udp.
 func (n *Node) handleError(addr *net.UDPAddr, e Error) {
 	spew.Dump(e)
-	n.rt.UpdateIfExists(Contact{id: e.NodeID, ip: addr.IP, port: addr.Port})
+	n.rt.Fresh(Contact{id: e.NodeID, ip: addr.IP, port: addr.Port})
 }
 
 // send sends data to a udp address
@@ -383,8 +397,8 @@ func (n *Node) SendAsync(ctx context.Context, contact Contact, req Request) <-chan *Response {
 			}
 		}
 
-		// if request timed out each time
-		n.rt.Remove(tx.contact.id)
+		// notify routing table about a failure to respond
+		n.rt.Fail(tx.contact)
 	}()
 
 	return ch
@@ -402,3 +416,19 @@ func (n *Node) CountActiveTransactions() int {
 	defer n.txLock.Unlock()
 	return len(n.transactions)
 }
+
+func (n *Node) startRoutingTableGrooming() {
+	n.stopWG.Add(1)
+	go func() {
+		defer n.stopWG.Done()
+		refreshTicker := time.NewTicker(tRefresh / 5) // how often to check for buckets that need to be refreshed
+		for {
+			select {
+			case <-refreshTicker.C:
+				RoutingTableRefresh(n, tRefresh, n.stop.Chan())
+			case <-n.stop.Chan():
+				return
+			}
+		}
+	}()
+}
diff --git a/dht/node_finder.go b/dht/node_finder.go
index fdf4f96..d2993a5 100644
--- a/dht/node_finder.go
+++ b/dht/node_finder.go
@@ -17,7 +17,8 @@ type contactFinder struct {
 	target Bitmap
 	node   *Node
 
-	done *stopOnce.Stopper
+	done   *stopOnce.Stopper
+	doneWG *sync.WaitGroup
 
 	findValueMutex  *sync.Mutex
 	findValueResult []Contact
@@ -48,10 +49,16 @@ func newContactFinder(node *Node, target Bitmap, findValue bool) *contactFinder {
 		shortlistMutex:           &sync.Mutex{},
 		shortlistAdded:           make(map[Bitmap]bool),
 		done:                     stopOnce.New(),
+		doneWG:                   &sync.WaitGroup{},
 		outstandingRequestsMutex: &sync.RWMutex{},
 	}
 }
 
+func (cf *contactFinder) Cancel() {
+	cf.done.Stop()
+	cf.doneWG.Wait()
+}
+
 func (cf *contactFinder) Find() (findNodeResponse, error) {
 	if cf.findValue {
 		log.Debugf("[%s] starting an iterative Find for the value %s", cf.node.id.HexShort(), cf.target.HexShort())
@@ -63,17 +70,15 @@
 		return findNodeResponse{}, errors.Err("no contacts in routing table")
 	}
 
-	wg := &sync.WaitGroup{}
-
 	for i := 0; i < alpha; i++ {
-		wg.Add(1)
+		cf.doneWG.Add(1)
 		go func(i int) {
-			defer wg.Done()
+			defer cf.doneWG.Done()
 			cf.iterationWorker(i + 1)
 		}(i)
 	}
 
-	wg.Wait()
+	cf.doneWG.Wait()
 
 	// TODO: what to do if we have less than K active contacts, shortlist is empty, but we
 	// TODO: have other contacts in our routing table whom we have not contacted. prolly contact them
@@ -133,7 +138,7 @@ func (cf *contactFinder) iterationWorker(num int) {
 
 		if res == nil {
 			// nothing to do, response timed out
-			log.Debugf("[%s] worker %d: timed out waiting for %s", cf.node.id.HexShort(), num, contact.id.HexShort())
+			log.Debugf("[%s] worker %d: search canceled or timed out waiting for %s", cf.node.id.HexShort(), num, contact.id.HexShort())
 		} else if cf.findValue && res.FindValueKey != "" {
 			log.Debugf("[%s] worker %d: got value", cf.node.id.HexShort(), num)
 			cf.findValueMutex.Lock()
diff --git a/dht/node_test.go b/dht/node_test.go
index e58223b..c2212b6 100644
--- a/dht/node_test.go
+++ b/dht/node_test.go
@@ -2,93 +2,12 @@ package dht
 
 import (
 	"net"
-	"strconv"
-	"strings"
 	"testing"
 	"time"
 
-	"github.com/lbryio/errors.go"
 	"github.com/lyoshenka/bencode"
 )
 
-type timeoutErr struct {
-	error
-}
-
-func (t timeoutErr) Timeout() bool {
-	return true
-}
-
-func (t timeoutErr) Temporary() bool {
-	return true
-}
-
-// TODO: just use a normal net.Conn instead of this mock conn
-
-type testUDPPacket struct {
-	data []byte
-	addr *net.UDPAddr
-}
-
-type testUDPConn struct {
-	addr   *net.UDPAddr
-	toRead chan testUDPPacket
-	writes chan testUDPPacket
-
-	readDeadline time.Time
-}
-
-func newTestUDPConn(addr string) *testUDPConn {
-	parts := strings.Split(addr, ":")
-	if len(parts) != 2 {
-		panic("addr needs ip and port")
-	}
-	port, err := strconv.Atoi(parts[1])
-	if err != nil {
-		panic(err)
-	}
-	return &testUDPConn{
-		addr:   &net.UDPAddr{IP: net.IP(parts[0]), Port: port},
-		toRead: make(chan testUDPPacket),
-		writes: make(chan testUDPPacket),
-	}
-}
-
-func (t testUDPConn) ReadFromUDP(b []byte) (int, *net.UDPAddr, error) {
-	var timeoutCh <-chan time.Time
-	if !t.readDeadline.IsZero() {
-		timeoutCh = time.After(t.readDeadline.Sub(time.Now()))
-	}
-
-	select {
-	case packet := <-t.toRead:
-		n := copy(b, packet.data)
-		return n, packet.addr, nil
-	case <-timeoutCh:
-		return 0, nil, timeoutErr{errors.Err("timeout")}
-	}
-}
-
-func (t testUDPConn) WriteToUDP(b []byte, addr *net.UDPAddr) (int, error) {
-	t.writes <- testUDPPacket{data: b, addr: addr}
-	return len(b), nil
-}
-
-func (t *testUDPConn) SetReadDeadline(tm time.Time) error {
-	t.readDeadline = tm
-	return nil
-}
-
-func (t *testUDPConn) SetWriteDeadline(tm time.Time) error {
-	return nil
-}
-
-func (t *testUDPConn) Close() error {
-	t.toRead = nil
-	t.writes = nil
-	return nil
-}
-
 func TestPing(t *testing.T) {
 	dhtNodeID := RandomBitmapP()
 	testNodeID := RandomBitmapP()
@@ -271,7 +190,7 @@ func TestStore(t *testing.T) {
 		}
 	}
 
-	if len(dht.node.store.hashes) != 1 {
+	if dht.node.store.CountStoredHashes() != 1 {
 		t.Error("dht store has wrong number of items")
 	}
@@ -517,164 +436,3 @@ func TestFindValueFallbackToFindNode(t *testing.T) {
 
 	verifyContacts(t, contacts, nodes)
 }
-
-func verifyResponse(t *testing.T, resp map[string]interface{}, id messageID, dhtNodeID string) {
-	if len(resp) != 4 {
-		t.Errorf("expected 4 response fields, got %d", len(resp))
-	}
-
-	_, ok := resp[headerTypeField]
-	if !ok {
-		t.Error("missing type field")
-	} else {
-		rType, ok := resp[headerTypeField].(int64)
-		if !ok {
-			t.Error("type is not an integer")
-		} else if rType != responseType {
-			t.Error("unexpected response type")
-		}
-	}
-
-	_, ok = resp[headerMessageIDField]
-	if !ok {
-		t.Error("missing message id field")
-	} else {
-		rMessageID, ok := resp[headerMessageIDField].(string)
-		if !ok {
-			t.Error("message ID is not a string")
-		} else if rMessageID != string(id[:]) {
-			t.Error("unexpected message ID")
-		}
-		if len(rMessageID) != messageIDLength {
-			t.Errorf("message ID should be %d chars long", messageIDLength)
-		}
-	}
-
-	_, ok = resp[headerNodeIDField]
-	if !ok {
-		t.Error("missing node id field")
-	} else {
-		rNodeID, ok := resp[headerNodeIDField].(string)
-		if !ok {
-			t.Error("node ID is not a string")
-		} else if rNodeID != dhtNodeID {
-			t.Error("unexpected node ID")
-		}
-		if len(rNodeID) != nodeIDLength {
-			t.Errorf("node ID should be %d chars long", nodeIDLength)
-		}
-	}
-}
-
-func verifyContacts(t *testing.T, contacts []interface{}, nodes []Contact) {
-	if len(contacts) != len(nodes) {
-		t.Errorf("got %d contacts; expected %d", len(contacts), len(nodes))
-		return
-	}
-
-	foundNodes := make(map[string]bool)
-
-	for _, c := range contacts {
-		contact, ok := c.([]interface{})
-		if !ok {
-			t.Error("contact is not a list")
-			return
-		}
-
-		if len(contact) != 3 {
-			t.Error("contact must be 3 items")
-			return
-		}
-
-		var currNode Contact
-		currNodeFound := false
-
-		id, ok := contact[0].(string)
-		if !ok {
-			t.Error("contact id is not a string")
-		} else {
-			if _, ok := foundNodes[id]; ok {
-				t.Errorf("contact %s appears multiple times", id)
-				continue
-			}
-			for _, n := range nodes {
-				if n.id.RawString() == id {
-					currNode = n
-					currNodeFound = true
-					foundNodes[id] = true
-					break
-				}
-			}
-			if !currNodeFound {
-				t.Errorf("unexpected contact %s", id)
-				continue
-			}
-		}
-
-		ip, ok := contact[1].(string)
-		if !ok {
-			t.Error("contact IP is not a string")
-		} else if !currNode.ip.Equal(net.ParseIP(ip)) {
-			t.Errorf("contact IP mismatch. got %s; expected %s", ip, currNode.ip.String())
-		}
-
-		port, ok := contact[2].(int64)
-		if !ok {
-			t.Error("contact port is not an int")
-		} else if int(port) != currNode.port {
-			t.Errorf("contact port mismatch. got %d; expected %d", port, currNode.port)
-		}
-	}
-}
-
-func verifyCompactContacts(t *testing.T, contacts []interface{}, nodes []Contact) {
-	if len(contacts) != len(nodes) {
-		t.Errorf("got %d contacts; expected %d", len(contacts), len(nodes))
-		return
-	}
-
-	foundNodes := make(map[string]bool)
-
-	for _, c := range contacts {
-		compact, ok := c.(string)
-		if !ok {
-			t.Error("contact is not a string")
-			return
-		}
-
-		contact := Contact{}
-		err := contact.UnmarshalCompact([]byte(compact))
-		if err != nil {
-			t.Error(err)
-			return
-		}
-
-		var currNode Contact
-		currNodeFound := false
-
-		if _, ok := foundNodes[contact.id.Hex()]; ok {
-			t.Errorf("contact %s appears multiple times", contact.id.Hex())
-			continue
-		}
-		for _, n := range nodes {
-			if n.id.Equals(contact.id) {
-				currNode = n
-				currNodeFound = true
-				foundNodes[contact.id.Hex()] = true
-				break
-			}
-		}
-		if !currNodeFound {
-			t.Errorf("unexpected contact %s", contact.id.Hex())
-			continue
-		}
-
-		if !currNode.ip.Equal(contact.ip) {
-			t.Errorf("contact IP mismatch. got %s; expected %s", contact.ip.String(), currNode.ip.String())
-		}
-
-		if contact.port != currNode.port {
-			t.Errorf("contact port mismatch. got %d; expected %d", contact.port, currNode.port)
-		}
-	}
-}
diff --git a/dht/routing_table.go b/dht/routing_table.go
index 81eb6dd..3de4abb 100644
--- a/dht/routing_table.go
+++ b/dht/routing_table.go
@@ -8,6 +8,7 @@ import (
 	"sort"
 	"strings"
 	"sync"
+	"time"
 
 	"github.com/lbryio/errors.go"
@@ -110,31 +111,175 @@ func (a byXorDistance) Less(i, j int) bool {
 	return a[i].xorDistanceToTarget.Less(a[j].xorDistanceToTarget)
 }
 
-type routingTable struct {
-	id      Bitmap
-	buckets [numBuckets]*list.List
-	lock    *sync.RWMutex
+// peer is a contact with extra freshness information
+type peer struct {
+	contact      Contact
+	lastActivity time.Time
+	numFailures  int
 }
 
+func (p *peer) Touch() {
+	p.lastActivity = time.Now()
+	p.numFailures = 0
+}
+
+// ActiveInLast returns whether a peer has responded in the last `d` duration
+// this is used to check if the peer is "good", meaning that we believe the peer will respond to our requests
+func (p *peer) ActiveInLast(d time.Duration) bool {
+	return time.Now().Sub(p.lastActivity) < d
+}
+
+// IsBad returns whether a peer is "bad", meaning that it has failed to respond to multiple pings in a row
+func (p *peer) IsBad(maxFailures int) bool {
+	return p.numFailures >= maxFailures
+}
+
+// Fail marks a peer as having failed to respond
+func (p *peer) Fail() {
+	p.numFailures++
+}
+
+// toPeer converts a generic *list.Element into a *peer
+// this (along with newPeer) keeps all conversions between *list.Element and peer in one place
+func toPeer(el *list.Element) *peer {
+	return el.Value.(*peer)
+}
+
+// newPeer creates a new peer from a contact
+// this (along with toPeer) keeps all conversions between *list.Element and peer in one place
+func newPeer(c Contact) peer {
+	return peer{
+		contact: c,
+	}
+}
+
+type bucket struct {
+	lock       *sync.RWMutex
+	peers      *list.List
+	lastUpdate time.Time
+}
+
+// Len returns the number of peers in the bucket
+func (b bucket) Len() int {
+	b.lock.RLock()
+	defer b.lock.RUnlock()
+	return b.peers.Len()
+}
+
+// Contacts returns a slice of the bucket's contacts
+func (b bucket) Contacts() []Contact {
+	b.lock.RLock()
+	defer b.lock.RUnlock()
+	contacts := make([]Contact, b.peers.Len())
+	for i, curr := 0, b.peers.Front(); curr != nil; i, curr = i+1, curr.Next() {
+		contacts[i] = toPeer(curr).contact
+	}
+	return contacts
+}
+
+// UpdateContact marks a contact as having been successfully contacted. if insertIfNew and the contact does not exist yet, it is inserted
+func (b *bucket) UpdateContact(c Contact, insertIfNew bool) {
+	b.lock.Lock()
+	defer b.lock.Unlock()
+
+	element := find(c.id, b.peers)
+	if element != nil {
+		b.lastUpdate = time.Now()
+		toPeer(element).Touch()
+		b.peers.MoveToBack(element)
+
+	} else if insertIfNew {
+		hasRoom := true
+
+		if b.peers.Len() >= bucketSize {
+			hasRoom = false
+			for curr := b.peers.Front(); curr != nil; curr = curr.Next() {
+				if toPeer(curr).IsBad(maxPeerFails) {
+					// TODO: Ping contact first. Only remove if it does not respond
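+					// for now, a peer that has failed too many pings is evicted immediately to make room for the new contact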
+					b.peers.Remove(curr)
+					hasRoom = true
+					break
+				}
+			}
+		}
+
+		if hasRoom {
+			b.lastUpdate = time.Now()
+			p := newPeer(c)
+			p.Touch()
+			b.peers.PushBack(&p)
+		}
+	}
+}
+
+// FailContact marks a contact as having failed to respond. contacts that fail too many times become eligible for eviction
+func (b *bucket) FailContact(id Bitmap) {
+	b.lock.Lock()
+	defer b.lock.Unlock()
+	element := find(id, b.peers)
+	if element != nil {
+		// BEP5 says not to remove the contact until the bucket is full and you try to insert
+		toPeer(element).Fail()
+	}
+}
+
+// find returns the contact in the bucket, or nil if the bucket does not contain the contact
+func find(id Bitmap, peers *list.List) *list.Element {
+	for curr := peers.Front(); curr != nil; curr = curr.Next() {
+		if toPeer(curr).contact.id.Equals(id) {
+			return curr
+		}
+	}
+	return nil
+}
+
+// NeedsRefresh returns true if the bucket has not been updated in the last `refreshInterval`, false otherwise
+func (b *bucket) NeedsRefresh(refreshInterval time.Duration) bool {
+	b.lock.RLock()
+	defer b.lock.RUnlock()
+	return time.Now().Sub(b.lastUpdate) > refreshInterval
+}
+
+type RoutingTable interface {
+	Update(Contact)
+	Fresh(Contact)
+	Fail(Contact)
+	GetClosest(Bitmap, int) []Contact
+	Count() int
+	GetIDsForRefresh(time.Duration) []Bitmap
+	BucketInfo() string // for debugging
+}
+
+type routingTableImpl struct {
+	id      Bitmap
+	buckets [numBuckets]bucket
+}
+
+func newRoutingTable(id Bitmap) *routingTableImpl {
+	var rt routingTableImpl
 	rt.id = id
-	rt.lock = &sync.RWMutex{}
+	for i := range rt.buckets {
+		rt.buckets[i] = bucket{
+			peers: list.New(),
+			lock:  &sync.RWMutex{},
+		}
+	}
 	return &rt
 }
 
-func (rt *routingTable) BucketInfo() string {
-	rt.lock.RLock()
-	defer rt.lock.RUnlock()
-
+func (rt *routingTableImpl) BucketInfo() string {
 	var bucketInfo []string
 	for i, b := range rt.buckets {
-		contents := bucketContents(b)
-		if contents != "" {
-			bucketInfo = append(bucketInfo, fmt.Sprintf("Bucket %d: %s", i, contents))
+		if b.Len() > 0 {
+			contacts := b.Contacts()
+			s := make([]string, len(contacts))
+			for j, c := range contacts {
+				s[j] = c.id.HexShort()
+			}
+			bucketInfo = append(bucketInfo, fmt.Sprintf("Bucket %d: (%d) %s", i, len(contacts), strings.Join(s, ", ")))
 		}
 	}
 	if len(bucketInfo) == 0 {
@@ -143,89 +288,41 @@
 	return strings.Join(bucketInfo, "\n")
 }
 
-func bucketContents(b *list.List) string {
-	count := 0
-	ids := ""
-	for curr := b.Front(); curr != nil; curr = curr.Next() {
-		count++
-		if ids != "" {
-			ids += ", "
-		}
-		ids += curr.Value.(Contact).id.HexShort()
-	}
-
-	if count > 0 {
-		return fmt.Sprintf("(%d) %s", count, ids)
-	} else {
-		return ""
-	}
-}
-
 // Update inserts or refreshes a contact
-func (rt *routingTable) Update(c Contact) {
-	rt.lock.Lock()
-	defer rt.lock.Unlock()
-	bucketNum := rt.bucketFor(c.id)
-	bucket := rt.buckets[bucketNum]
-	element := findInList(bucket, c.id)
-	if element == nil {
-		if bucket.Len() >= bucketSize {
-			// TODO: Ping front contact first. Only remove if it does not respond
-			bucket.Remove(bucket.Front())
-		}
-		bucket.PushBack(c)
-	} else {
-		bucket.MoveToBack(element)
-	}
+func (rt *routingTableImpl) Update(c Contact) {
+	rt.bucketFor(c.id).UpdateContact(c, true)
 }
 
-// UpdateIfExists refreshes a contact if its already in the routing table
-func (rt *routingTable) UpdateIfExists(c Contact) {
-	rt.lock.Lock()
-	defer rt.lock.Unlock()
-	bucketNum := rt.bucketFor(c.id)
-	bucket := rt.buckets[bucketNum]
-	element := findInList(bucket, c.id)
-	if element != nil {
-		bucket.MoveToBack(element)
-	}
+// Fresh refreshes a contact if it's already in the routing table
+func (rt *routingTableImpl) Fresh(c Contact) {
+	rt.bucketFor(c.id).UpdateContact(c, false)
 }
 
-func (rt *routingTable) Remove(id Bitmap) {
-	rt.lock.Lock()
-	defer rt.lock.Unlock()
-	bucketNum := rt.bucketFor(id)
-	bucket := rt.buckets[bucketNum]
-	element := findInList(bucket, rt.id)
-	if element != nil {
-		bucket.Remove(element)
-	}
+// Fail marks a contact as having failed to respond; contacts that fail too many times become eligible for eviction
+func (rt *routingTableImpl) Fail(c Contact) {
+	rt.bucketFor(c.id).FailContact(c.id)
 }
 
-func (rt *routingTable) GetClosest(target Bitmap, limit int) []Contact {
-	rt.lock.RLock()
-	defer rt.lock.RUnlock()
-
+// GetClosest returns the closest `limit` contacts from the routing table
+func (rt *routingTableImpl) GetClosest(target Bitmap, limit int) []Contact {
 	var toSort []sortedContact
 	var bucketNum int
 
 	if rt.id.Equals(target) {
 		bucketNum = 0
 	} else {
-		bucketNum = rt.bucketFor(target)
+		bucketNum = rt.bucketNumFor(target)
 	}
 
-	bucket := rt.buckets[bucketNum]
-	toSort = appendContacts(toSort, bucket.Front(), target)
+	toSort = appendContacts(toSort, rt.buckets[bucketNum], target)
 
 	for i := 1; (bucketNum-i >= 0 || bucketNum+i < numBuckets) && len(toSort) < limit; i++ {
 		if bucketNum-i >= 0 {
-			bucket = rt.buckets[bucketNum-i]
-			toSort = appendContacts(toSort, bucket.Front(), target)
+			toSort = appendContacts(toSort, rt.buckets[bucketNum-i], target)
 		}
 		if bucketNum+i < numBuckets {
-			bucket = rt.buckets[bucketNum+i]
-			toSort = appendContacts(toSort, bucket.Front(), target)
+			toSort = appendContacts(toSort, rt.buckets[bucketNum+i], target)
 		}
 	}
 
@@ -242,43 +339,75 @@
 	return contacts
 }
 
-func appendContacts(contacts []sortedContact, start *list.Element, target Bitmap) []sortedContact {
-	for curr := start; curr != nil; curr = curr.Next() {
-		c := toContact(curr)
-		contacts = append(contacts, sortedContact{c, c.id.Xor(target)})
+func appendContacts(contacts []sortedContact, b bucket, target Bitmap) []sortedContact {
+	for _, contact := range b.Contacts() {
+		contacts = append(contacts, sortedContact{contact, contact.id.Xor(target)})
 	}
 	return contacts
 }
 
 // Count returns the number of contacts in the routing table
-func (rt *routingTable) Count() int {
-	rt.lock.RLock()
-	defer rt.lock.RUnlock()
+func (rt *routingTableImpl) Count() int {
 	count := 0
 	for _, bucket := range rt.buckets {
-		for curr := bucket.Front(); curr != nil; curr = curr.Next() {
-			count++
-		}
+		count += bucket.Len()
 	}
 	return count
 }
 
-func (rt *routingTable) bucketFor(target Bitmap) int {
+func (rt *routingTableImpl) bucketNumFor(target Bitmap) int {
 	if rt.id.Equals(target) {
 		panic("routing table does not have a bucket for its own id")
 	}
 	return numBuckets - 1 - target.Xor(rt.id).PrefixLen()
 }
 
-func findInList(bucket *list.List, value Bitmap) *list.Element {
-	for curr := bucket.Front(); curr != nil; curr = curr.Next() {
-		if toContact(curr).id.Equals(value) {
-			return curr
-		}
-	}
-	return nil
 }
 
-func toContact(el *list.Element) Contact {
-	return el.Value.(Contact)
+func (rt *routingTableImpl) bucketFor(target Bitmap) *bucket {
+	return &rt.buckets[rt.bucketNumFor(target)]
 }
+
+func (rt *routingTableImpl) GetIDsForRefresh(refreshInterval time.Duration) []Bitmap {
+	var bitmaps []Bitmap
+	for i, bucket := range rt.buckets {
+		if bucket.NeedsRefresh(refreshInterval) {
+			bitmaps = append(bitmaps, RandomBitmapP().ZeroPrefix(i))
+		}
+	}
+	return bitmaps
+}
+
+// RoutingTableRefresh refreshes any buckets that need to be refreshed
+// It returns a channel that will be closed when the refresh is done
+func RoutingTableRefresh(n *Node, refreshInterval time.Duration, cancel <-chan struct{}) <-chan struct{} {
+	done := make(chan struct{})
+
+	var wg sync.WaitGroup
+
+	for _, id := range n.rt.GetIDsForRefresh(refreshInterval) {
+		wg.Add(1)
+		go func(id Bitmap) {
+			defer wg.Done()
+
+			nf := newContactFinder(n, id, false)
+
+			if cancel != nil {
+				go func() {
+					select {
+					case <-cancel:
+						nf.Cancel()
+					case <-done:
+					}
+				}()
+			}
+
+			nf.Find()
+		}(id)
+	}
+
+	go func() {
+		wg.Wait()
+		close(done)
+	}()
+
+	return done
 }
diff --git a/dht/routing_table_test.go b/dht/routing_table_test.go
index e4451c3..cedb946 100644
--- a/dht/routing_table_test.go
+++ b/dht/routing_table_test.go
@@ -7,27 +7,26 @@ import (
 )
 
 func TestRoutingTable_bucketFor(t *testing.T) {
-	target := BitmapFromHexP("000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000")
+	rt := newRoutingTable(BitmapFromHexP("000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000"))
 
 	var tests = []struct {
 		id       Bitmap
-		target   Bitmap
 		expected int
 	}{
-		{BitmapFromHexP("000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000001"), target, 0},
-		{BitmapFromHexP("000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000002"), target, 1},
-		{BitmapFromHexP("000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000003"), target, 1},
-		{BitmapFromHexP("000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000004"), target, 2},
-		{BitmapFromHexP("000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000005"), target, 2},
-		{BitmapFromHexP("00000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000f"), target, 3},
-		{BitmapFromHexP("000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000010"), target, 4},
-		{BitmapFromHexP("F00000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000"), target, 383},
-		{BitmapFromHexP("F0000000000000000000000000000000F0000000000000000000000000F0000000000000000000000000000000000000"), target, 383},
+		{BitmapFromHexP("000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000001"), 0},
+		{BitmapFromHexP("000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000002"), 1},
+		{BitmapFromHexP("000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000003"), 1},
+		{BitmapFromHexP("000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000004"), 2},
+		{BitmapFromHexP("000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000005"), 2},
{BitmapFromHexP("000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000005"), 2}, + {BitmapFromHexP("00000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000f"), 3}, + {BitmapFromHexP("000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000010"), 4}, + {BitmapFromHexP("F00000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000"), 383}, + {BitmapFromHexP("F0000000000000000000000000000000F0000000000000000000000000F0000000000000000000000000000000000000"), 383}, } for _, tt := range tests { - bucket := bucketFor(tt.id, tt.target) + bucket := rt.bucketNumFor(tt.id) if bucket != tt.expected { - t.Errorf("bucketFor(%s, %s) => %d, want %d", tt.id.Hex(), tt.target.Hex(), bucket, tt.expected) + t.Errorf("bucketFor(%s, %s) => %d, want %d", tt.id.Hex(), rt.id.Hex(), bucket, tt.expected) } } } @@ -83,3 +82,7 @@ func TestCompactEncoding(t *testing.T) { t.Errorf("compact bytes not encoded correctly") } } + +func TestRoutingTableRefresh(t *testing.T) { + t.Skip("TODO: test routing table refreshing") +} diff --git a/dht/store.go b/dht/store.go index 70a4056..e8a5faf 100644 --- a/dht/store.go +++ b/dht/store.go @@ -2,7 +2,13 @@ package dht import "sync" -type peerStore struct { +type Store interface { + Upsert(Bitmap, Contact) + Get(Bitmap) []Contact + CountStoredHashes() int +} + +type storeImpl struct { // map of blob hashes to (map of node IDs to bools) hashes map[Bitmap]map[Bitmap]bool // stores the peers themselves, so they can be updated in one place @@ -10,14 +16,14 @@ type peerStore struct { lock sync.RWMutex } -func newPeerStore() *peerStore { - return &peerStore{ +func newStore() *storeImpl { + return &storeImpl{ hashes: make(map[Bitmap]map[Bitmap]bool), contacts: make(map[Bitmap]Contact), } } -func (s *peerStore) Upsert(blobHash Bitmap, contact Contact) { +func (s *storeImpl) Upsert(blobHash Bitmap, contact Contact) { s.lock.Lock() defer s.lock.Unlock() @@ -28,7 +34,7 @@ func (s *peerStore) Upsert(blobHash Bitmap, contact Contact) { s.contacts[contact.id] = contact } -func (s *peerStore) Get(blobHash Bitmap) []Contact { +func (s *storeImpl) Get(blobHash Bitmap) []Contact { s.lock.RLock() defer s.lock.RUnlock() @@ -45,11 +51,11 @@ func (s *peerStore) Get(blobHash Bitmap) []Contact { return contacts } -func (s *peerStore) RemoveTODO(contact Contact) { +func (s *storeImpl) RemoveTODO(contact Contact) { // TODO: remove peer from everywhere } -func (s *peerStore) CountStoredHashes() int { +func (s *storeImpl) CountStoredHashes() int { s.lock.RLock() defer s.lock.RUnlock() return len(s.hashes) diff --git a/dht/testing.go b/dht/testing.go index a22bd71..111138b 100644 --- a/dht/testing.go +++ b/dht/testing.go @@ -1,23 +1,40 @@ package dht -import "strconv" +import ( + "net" + "strconv" + "strings" + "testing" + "time" -func TestingCreateDHT(numNodes int) []*DHT { + "github.com/lbryio/errors.go" +) + +var testingDHTIP = "127.0.0.1" +var testingDHTFirstPort = 21000 + +func TestingCreateDHT(numNodes int) (*BootstrapNode, []*DHT) { + bootstrapAddress := testingDHTIP + ":" + strconv.Itoa(testingDHTFirstPort) + bootstrapNode := NewBootstrapNode(RandomBitmapP(), 0, bootstrapDefaultRefreshDuration) + listener, err := net.ListenPacket(network, bootstrapAddress) + if err != nil { + panic(err) + } + bootstrapNode.Connect(listener.(*net.UDPConn)) + + return bootstrapNode, TestingCreateDHTNoBootstrap(numNodes, []string{bootstrapAddress}) +} + +func 
 	if numNodes < 1 {
 		return nil
 	}
 
-	ip := "127.0.0.1"
-	firstPort := 21000
+	firstPort := testingDHTFirstPort + 1
 
 	dhts := make([]*DHT, numNodes)
 
 	for i := 0; i < numNodes; i++ {
-		seeds := []string{}
-		if i > 0 {
-			seeds = []string{ip + ":" + strconv.Itoa(firstPort)}
-		}
-
-		dht, err := New(&Config{Address: ip + ":" + strconv.Itoa(firstPort+i), NodeID: RandomBitmapP().Hex(), SeedNodes: seeds})
+		dht, err := New(&Config{Address: testingDHTIP + ":" + strconv.Itoa(firstPort+i), NodeID: RandomBitmapP().Hex(), SeedNodes: seeds})
 		if err != nil {
 			panic(err)
 		}
@@ -29,3 +46,242 @@
 
 	return dhts
 }
+
+type timeoutErr struct {
+	error
+}
+
+func (t timeoutErr) Timeout() bool {
+	return true
+}
+
+func (t timeoutErr) Temporary() bool {
+	return true
+}
+
+// TODO: just use a normal net.Conn instead of this mock conn
+
+type testUDPPacket struct {
+	data []byte
+	addr *net.UDPAddr
+}
+
+type testUDPConn struct {
+	addr   *net.UDPAddr
+	toRead chan testUDPPacket
+	writes chan testUDPPacket
+
+	readDeadline time.Time
+}
+
+func newTestUDPConn(addr string) *testUDPConn {
+	parts := strings.Split(addr, ":")
+	if len(parts) != 2 {
+		panic("addr needs ip and port")
+	}
+	port, err := strconv.Atoi(parts[1])
+	if err != nil {
+		panic(err)
+	}
+	return &testUDPConn{
+		addr:   &net.UDPAddr{IP: net.ParseIP(parts[0]), Port: port},
+		toRead: make(chan testUDPPacket),
+		writes: make(chan testUDPPacket),
+	}
+}
+
+func (t testUDPConn) ReadFromUDP(b []byte) (int, *net.UDPAddr, error) {
+	var timeoutCh <-chan time.Time
+	if !t.readDeadline.IsZero() {
+		timeoutCh = time.After(t.readDeadline.Sub(time.Now()))
+	}
+
+	select {
+	case packet := <-t.toRead:
+		n := copy(b, packet.data)
+		return n, packet.addr, nil
+	case <-timeoutCh:
+		return 0, nil, timeoutErr{errors.Err("timeout")}
+	}
+}
+
+func (t testUDPConn) WriteToUDP(b []byte, addr *net.UDPAddr) (int, error) {
+	t.writes <- testUDPPacket{data: b, addr: addr}
+	return len(b), nil
+}
+
+func (t *testUDPConn) SetReadDeadline(tm time.Time) error {
+	t.readDeadline = tm
+	return nil
+}
+
+func (t *testUDPConn) SetWriteDeadline(tm time.Time) error {
+	return nil
+}
+
+func (t *testUDPConn) Close() error {
+	t.toRead = nil
+	t.writes = nil
+	return nil
+}
+
+func verifyResponse(t *testing.T, resp map[string]interface{}, id messageID, dhtNodeID string) {
+	if len(resp) != 4 {
+		t.Errorf("expected 4 response fields, got %d", len(resp))
+	}
+
+	_, ok := resp[headerTypeField]
+	if !ok {
+		t.Error("missing type field")
+	} else {
+		rType, ok := resp[headerTypeField].(int64)
+		if !ok {
+			t.Error("type is not an integer")
+		} else if rType != responseType {
+			t.Error("unexpected response type")
+		}
+	}
+
+	_, ok = resp[headerMessageIDField]
+	if !ok {
+		t.Error("missing message id field")
+	} else {
+		rMessageID, ok := resp[headerMessageIDField].(string)
+		if !ok {
+			t.Error("message ID is not a string")
+		} else if rMessageID != string(id[:]) {
+			t.Error("unexpected message ID")
+		}
+		if len(rMessageID) != messageIDLength {
+			t.Errorf("message ID should be %d chars long", messageIDLength)
+		}
+	}
+
+	_, ok = resp[headerNodeIDField]
+	if !ok {
+		t.Error("missing node id field")
+	} else {
+		rNodeID, ok := resp[headerNodeIDField].(string)
+		if !ok {
+			t.Error("node ID is not a string")
+		} else if rNodeID != dhtNodeID {
+			t.Error("unexpected node ID")
+		}
+		if len(rNodeID) != nodeIDLength {
+			t.Errorf("node ID should be %d chars long", nodeIDLength)
+		}
+	}
+}
+
+func verifyContacts(t *testing.T, contacts []interface{}, nodes []Contact) {
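+	// every node in `nodes` must appear exactly once in `contacts`, with matching IP and port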
+	if len(contacts) != len(nodes) {
+		t.Errorf("got %d contacts; expected %d", len(contacts), len(nodes))
+		return
+	}
+
+	foundNodes := make(map[string]bool)
+
+	for _, c := range contacts {
+		contact, ok := c.([]interface{})
+		if !ok {
+			t.Error("contact is not a list")
+			return
+		}
+
+		if len(contact) != 3 {
+			t.Error("contact must be 3 items")
+			return
+		}
+
+		var currNode Contact
+		currNodeFound := false
+
+		id, ok := contact[0].(string)
+		if !ok {
+			t.Error("contact id is not a string")
+		} else {
+			if _, ok := foundNodes[id]; ok {
+				t.Errorf("contact %s appears multiple times", id)
+				continue
+			}
+			for _, n := range nodes {
+				if n.id.RawString() == id {
+					currNode = n
+					currNodeFound = true
+					foundNodes[id] = true
+					break
+				}
+			}
+			if !currNodeFound {
+				t.Errorf("unexpected contact %s", id)
+				continue
+			}
+		}
+
+		ip, ok := contact[1].(string)
+		if !ok {
+			t.Error("contact IP is not a string")
+		} else if !currNode.ip.Equal(net.ParseIP(ip)) {
+			t.Errorf("contact IP mismatch. got %s; expected %s", ip, currNode.ip.String())
+		}
+
+		port, ok := contact[2].(int64)
+		if !ok {
+			t.Error("contact port is not an int")
+		} else if int(port) != currNode.port {
+			t.Errorf("contact port mismatch. got %d; expected %d", port, currNode.port)
+		}
+	}
+}
+
+func verifyCompactContacts(t *testing.T, contacts []interface{}, nodes []Contact) {
+	if len(contacts) != len(nodes) {
+		t.Errorf("got %d contacts; expected %d", len(contacts), len(nodes))
+		return
+	}
+
+	foundNodes := make(map[string]bool)
+
+	for _, c := range contacts {
+		compact, ok := c.(string)
+		if !ok {
+			t.Error("contact is not a string")
+			return
+		}
+
+		contact := Contact{}
+		err := contact.UnmarshalCompact([]byte(compact))
+		if err != nil {
+			t.Error(err)
+			return
+		}
+
+		var currNode Contact
+		currNodeFound := false
+
+		if _, ok := foundNodes[contact.id.Hex()]; ok {
+			t.Errorf("contact %s appears multiple times", contact.id.Hex())
+			continue
+		}
+		for _, n := range nodes {
+			if n.id.Equals(contact.id) {
+				currNode = n
+				currNodeFound = true
+				foundNodes[contact.id.Hex()] = true
+				break
+			}
+		}
+		if !currNodeFound {
+			t.Errorf("unexpected contact %s", contact.id.Hex())
+			continue
+		}
+
+		if !currNode.ip.Equal(contact.ip) {
+			t.Errorf("contact IP mismatch. got %s; expected %s", contact.ip.String(), currNode.ip.String())
+		}
+
+		if contact.port != currNode.port {
+			t.Errorf("contact port mismatch. got %d; expected %d", contact.port, currNode.port)
+		}
+	}
+}
diff --git a/store/s3.go b/store/s3.go
index 9332a53..0eac518 100644
--- a/store/s3.go
+++ b/store/s3.go
@@ -79,7 +79,7 @@ func (s *S3BlobStore) Get(hash string) ([]byte, error) {
 	log.Debugf("Getting %s from S3", hash[:8])
 
 	defer func(t time.Time) {
-		log.Debugf("Getting %s took %s", hash[:8], time.Since(t).String())
+		log.Debugf("Getting %s from S3 took %s", hash[:8], time.Since(t).String())
 	}(time.Now())
 
 	buf := &aws.WriteAtBuffer{}