diff --git a/Gopkg.lock b/Gopkg.lock index 9b975d9..771c76f 100644 --- a/Gopkg.lock +++ b/Gopkg.lock @@ -146,9 +146,16 @@ packages = [ "errors", "null", - "querytools" + "querytools", + "stopOnce" ] - revision = "fe6bc5bb14af1013b1cabb944d4d06413e7e2c8e" + revision = "a5d371ca4780841b033afe486a94f8eb80a94259" + +[[projects]] + branch = "master" + name = "github.com/lyoshenka/bencode" + packages = ["."] + revision = "d522839ac797fc43269dae6a04a1f8be475a915d" [[projects]] name = "github.com/miekg/dns" @@ -186,12 +193,6 @@ revision = "e57e3eeb33f795204c1ca35f56c44f83227c6e66" version = "v1.0.0" -[[projects]] - branch = "master" - name = "github.com/zeebo/bencode" - packages = ["."] - revision = "d522839ac797fc43269dae6a04a1f8be475a915d" - [[projects]] branch = "master" name = "golang.org/x/crypto" @@ -238,6 +239,6 @@ [solve-meta] analyzer-name = "dep" analyzer-version = 1 - inputs-digest = "21c8c5a2ce6478383360f22fd1ddf6344167c332cd0d32121b048abd97ca5cac" + inputs-digest = "ca9cc627801c67c407d872b78606a706e783135aa0665e0daf3688a69fad3712" solver-name = "gps-cdcl" solver-version = 1 diff --git a/cmd/dht.go b/cmd/dht.go index 0c73d07..4b913c3 100644 --- a/cmd/dht.go +++ b/cmd/dht.go @@ -26,5 +26,5 @@ func dhtCmd(cmd *cobra.Command, args []string) { log.Fatal(err) } - dht.Run() + dht.Start() } diff --git a/dht/bitmap.go b/dht/bitmap.go index a1236ca..e8d9caf 100644 --- a/dht/bitmap.go +++ b/dht/bitmap.go @@ -18,6 +18,10 @@ func (b bitmap) Hex() string { return hex.EncodeToString(b[0:nodeIDLength]) } +func (b bitmap) HexShort() string { + return hex.EncodeToString(b[0:nodeIDLength])[:8] +} + func (b bitmap) Equals(other bitmap) bool { for k := range b { if b[k] != other[k] { diff --git a/dht/dht.go b/dht/dht.go index 66abd66..a6688bd 100644 --- a/dht/dht.go +++ b/dht/dht.go @@ -1,10 +1,13 @@ package dht import ( + "context" "net" + "sync" "time" "github.com/lbryio/errors.go" + "github.com/lbryio/lbry.go/stopOnce" log "github.com/sirupsen/logrus" "github.com/spf13/cast" @@ -62,6 +65,7 @@ func NewStandardConfig() *Config { type UDPConn interface { ReadFromUDP([]byte) (int, *net.UDPAddr, error) WriteToUDP([]byte, *net.UDPAddr) (int, error) + SetReadDeadline(time.Time) error SetWriteDeadline(time.Time) error } @@ -74,6 +78,7 @@ type DHT struct { packets chan packet store *peerStore tm *transactionManager + stop *stopOnce.Stopper } // New returns a DHT pointer. If config is nil, then config will be set to the default config. @@ -114,6 +119,7 @@ func New(config *Config) (*DHT, error) { rt: newRoutingTable(node), packets: make(chan packet), store: newPeerStore(), + stop: stopOnce.New(), } d.tm = newTransactionManager(d) return d, nil @@ -140,37 +146,52 @@ func (dht *DHT) init() error { // listen receives message from udp. func (dht *DHT) listen() { - go func() { - buf := make([]byte, 8192) - for { - n, raddr, err := dht.conn.ReadFromUDP(buf) - if err != nil { - log.Errorf("udp read error: %v", err) - continue - } else if raddr == nil { - log.Errorf("udp read with no raddr") - continue - } - dht.packets <- packet{data: buf[:n], raddr: raddr} + buf := make([]byte, 8192) + for { + select { + case <-dht.stop.Chan(): + return + default: } - }() + + dht.conn.SetReadDeadline(time.Now().Add(2 * time.Second)) // need this to periodically check shutdown chan + + n, raddr, err := dht.conn.ReadFromUDP(buf) + if err != nil { + if e, ok := err.(net.Error); !ok || !e.Timeout() { + log.Errorf("udp read error: %v", err) + } + continue + } else if raddr == nil { + log.Errorf("udp read with no raddr") + continue + } + + dht.packets <- packet{data: buf[:n], raddr: raddr} + } } // join makes current node join the dht network. func (dht *DHT) join() { + // get real node IDs and add them to the routing table for _, addr := range dht.conf.SeedNodes { raddr, err := net.ResolveUDPAddr(network, addr) if err != nil { + log.Errorln(err) continue } - _ = raddr + tmpNode := Node{id: newRandomBitmap(), ip: raddr.IP, port: raddr.Port} + res := dht.tm.Send(tmpNode, &Request{Method: pingMethod}) + if res == nil { + log.Errorf("[%s] join: no response from seed node %s", dht.node.id.HexShort(), addr) + } + } - // NOTE: Temporary node has NO node id. - //dht.transactionManager.findNode( - // &node{addr: raddr}, - // dht.node.id.RawString(), - //) + // now call iterativeFind on yourself + _, err := dht.FindNodes(dht.node.id) + if err != nil { + log.Error(err) } } @@ -181,24 +202,33 @@ func (dht *DHT) runHandler() { select { case pkt = <-dht.packets: handlePacket(dht, pkt) + case <-dht.stop.Chan(): + return } } } -// Run starts the dht. -func (dht *DHT) Run() error { +// Start starts the dht +func (dht *DHT) Start() error { err := dht.init() if err != nil { return err } - dht.listen() + go dht.listen() + go dht.runHandler() + dht.join() - log.Info("DHT ready") - dht.runHandler() + log.Infof("[%s] DHT ready", dht.node.id.HexShort()) return nil } +// Shutdown shuts down the dht +func (dht *DHT) Shutdown() { + log.Infof("[%s] DHT shutting down", dht.node.id.HexShort()) + dht.stop.Stop() +} + func printState(dht *DHT) { t := time.NewTicker(60 * time.Second) for { @@ -210,68 +240,239 @@ func printState(dht *DHT) { } } -//func (dht *DHT) Get(hash bitmap) ([]Node, error) { -// return iterativeFindNode(dht, hash) -//} -// -//func iterativeFindNode(dht *DHT, hash bitmap) ([]Node, error) { -// shortlist := dht.rt.FindClosest(hash, alpha) -// if len(shortlist) == 0 { -// return nil, errors.Err("no nodes in routing table") -// } -// -// pending := make(chan *Node) -// contacted := make(map[bitmap]bool) -// contactedMutex := &sync.Mutex{} -// closestNodeMutex := &sync.Mutex{} -// closestNode := shortlist[0] -// wg := sync.WaitGroup{} -// -// for i := 0; i < alpha; i++ { -// wg.Add(1) -// go func() { -// defer wg.Done() -// for { -// node, ok := <-pending -// if !ok { -// return -// } -// -// contactedMutex.Lock() -// if _, ok := contacted[node.id]; ok { -// contactedMutex.Unlock() -// continue -// } -// contacted[node.id] = true -// contactedMutex.Unlock() -// -// res := dht.tm.Send(node, &Request{ -// NodeID: dht.node.id.RawString(), -// Method: findNodeMethod, -// Args: []string{hash.RawString()}, -// }) -// if res == nil { -// // remove node from shortlist -// continue -// } -// -// for _, n := range res.FindNodeData { -// pending <- &n -// closestNodeMutex.Lock() -// if n.id.Xor(hash).Less(closestNode.id.Xor(hash)) { -// closestNode = &n -// } -// closestNodeMutex.Unlock() -// } -// } -// }() -// } -// -// for _, n := range shortlist { -// pending <- n -// } -// -// wg.Wait() -// -// return nil, nil -//} +func (dht *DHT) FindNodes(hash bitmap) ([]Node, error) { + nf := newNodeFinder(dht, hash, false) + res, err := nf.Find() + if err != nil { + return nil, err + } + return res.Nodes, nil +} + +func (dht *DHT) FindValue(hash bitmap) ([]Node, bool, error) { + nf := newNodeFinder(dht, hash, true) + res, err := nf.Find() + if err != nil { + return nil, false, err + } + return res.Nodes, res.Found, nil +} + +type nodeFinder struct { + findValue bool // true if we're using findValue + target bitmap + dht *DHT + + done *stopOnce.Stopper + + findValueMutex *sync.Mutex + findValueResult []Node + + activeNodesMutex *sync.Mutex + activeNodes []Node + + shortlistMutex *sync.Mutex + shortlist []Node + + contactedMutex *sync.RWMutex + contacted map[bitmap]bool +} + +type findNodeResponse struct { + Found bool + Nodes []Node +} + +func newNodeFinder(dht *DHT, target bitmap, findValue bool) *nodeFinder { + return &nodeFinder{ + dht: dht, + target: target, + findValue: findValue, + findValueMutex: &sync.Mutex{}, + activeNodesMutex: &sync.Mutex{}, + contactedMutex: &sync.RWMutex{}, + shortlistMutex: &sync.Mutex{}, + contacted: make(map[bitmap]bool), + done: stopOnce.New(), + } +} + +func (nf *nodeFinder) Find() (findNodeResponse, error) { + log.Debugf("[%s] starting an iterative Find() for %s (findValue is %t)", nf.dht.node.id.HexShort(), nf.target.HexShort(), nf.findValue) + nf.appendNewToShortlist(nf.dht.rt.GetClosest(nf.target, alpha)) + if len(nf.shortlist) == 0 { + return findNodeResponse{}, errors.Err("no nodes in routing table") + } + + wg := &sync.WaitGroup{} + + for i := 0; i < alpha; i++ { + wg.Add(1) + go func(i int) { + defer wg.Done() + nf.iterationWorker(i + 1) + }(i) + } + + wg.Wait() + + // TODO: what to do if we have less than K active nodes, shortlist is empty, but we + // TODO: have other nodes in our routing table whom we have not contacted. prolly contact them? + + result := findNodeResponse{} + if nf.findValue && len(nf.findValueResult) > 0 { + result.Found = true + result.Nodes = nf.findValueResult + } else { + result.Nodes = nf.activeNodes + if len(result.Nodes) > bucketSize { + result.Nodes = result.Nodes[:bucketSize] + } + } + + return result, nil +} + +func (nf *nodeFinder) iterationWorker(num int) { + log.Debugf("[%s] starting worker %d", nf.dht.node.id.HexShort(), num) + defer func() { log.Debugf("[%s] stopping worker %d", nf.dht.node.id.HexShort(), num) }() + + for { + maybeNode := nf.popFromShortlist() + if maybeNode == nil { + // TODO: block if there are pending requests out from other workers. there may be more shortlist values coming + log.Debugf("[%s] no more nodes in short list", nf.dht.node.id.HexShort()) + return + } + node := *maybeNode + + if node.id.Equals(nf.dht.node.id) { + continue // cannot contact self + } + + req := &Request{Args: []string{nf.target.RawString()}} + if nf.findValue { + req.Method = findValueMethod + } else { + req.Method = findNodeMethod + } + + log.Debugf("[%s] contacting %s", nf.dht.node.id.HexShort(), node.id.HexShort()) + + var res *Response + ctx, cancel := context.WithCancel(context.Background()) + resCh := nf.dht.tm.SendAsync(ctx, node, req) + select { + case res = <-resCh: + case <-nf.done.Chan(): + log.Debugf("[%s] worker %d: canceled", nf.dht.node.id.HexShort(), num) + cancel() + return + } + + if res == nil { + // nothing to do, response timed out + } else if nf.findValue && res.FindValueKey != "" { + log.Debugf("[%s] worker %d: got value", nf.dht.node.id.HexShort(), num) + nf.findValueMutex.Lock() + nf.findValueResult = res.FindNodeData + nf.findValueMutex.Unlock() + nf.done.Stop() + return + } else { + log.Debugf("[%s] worker %d: got more contacts", nf.dht.node.id.HexShort(), num) + nf.insertIntoActiveList(node) + nf.markContacted(node) + nf.appendNewToShortlist(res.FindNodeData) + } + + if nf.isSearchFinished() { + log.Debugf("[%s] worker %d: search is finished", nf.dht.node.id.HexShort(), num) + nf.done.Stop() + return + } + } +} + +func (nf *nodeFinder) filterContacted(nodes []Node) []Node { + nf.contactedMutex.RLock() + defer nf.contactedMutex.RUnlock() + filtered := []Node{} + for _, n := range nodes { + if ok := nf.contacted[n.id]; !ok { + filtered = append(filtered, n) + } + } + return filtered +} + +func (nf *nodeFinder) markContacted(node Node) { + nf.contactedMutex.Lock() + defer nf.contactedMutex.Unlock() + nf.contacted[node.id] = true +} + +func (nf *nodeFinder) appendNewToShortlist(nodes []Node) { + nf.shortlistMutex.Lock() + defer nf.shortlistMutex.Unlock() + nf.shortlist = append(nf.shortlist, nf.filterContacted(nodes)...) + sortNodesInPlace(nf.shortlist, nf.target) +} + +func (nf *nodeFinder) popFromShortlist() *Node { + nf.shortlistMutex.Lock() + defer nf.shortlistMutex.Unlock() + if len(nf.shortlist) == 0 { + return nil + } + first := nf.shortlist[0] + nf.shortlist = nf.shortlist[1:] + return &first +} + +func (nf *nodeFinder) insertIntoActiveList(node Node) { + nf.activeNodesMutex.Lock() + defer nf.activeNodesMutex.Unlock() + + inserted := false + for i, n := range nf.activeNodes { + if node.id.Xor(nf.target).Less(n.id.Xor(nf.target)) { + nf.activeNodes = append(nf.activeNodes[:i], append([]Node{node}, nf.activeNodes[i:]...)...) + inserted = true + } + } + if !inserted { + nf.activeNodes = append(nf.activeNodes, node) + } +} + +func (nf *nodeFinder) isSearchFinished() bool { + if nf.findValue && len(nf.findValueResult) > 0 { + // if we have a result, always break + return true + } + + select { + case <-nf.done.Chan(): + return true + default: + } + + nf.shortlistMutex.Lock() + defer nf.shortlistMutex.Unlock() + + if len(nf.shortlist) == 0 { + // no more nodes to contact + return true + } + + nf.activeNodesMutex.Lock() + defer nf.activeNodesMutex.Unlock() + + if len(nf.activeNodes) >= bucketSize && nf.activeNodes[bucketSize-1].id.Xor(nf.target).Less(nf.shortlist[0].id.Xor(nf.target)) { + // we have at least K active nodes, and we don't have any closer nodes yet to contact + return true + } + + return false +} diff --git a/dht/dht_test.go b/dht/dht_test.go new file mode 100644 index 0000000..6866adc --- /dev/null +++ b/dht/dht_test.go @@ -0,0 +1,44 @@ +package dht + +import ( + "testing" + "time" + + "github.com/davecgh/go-spew/spew" +) + +func TestDHT_FindNodes(t *testing.T) { + //log.SetLevel(log.DebugLevel) + + id1 := newRandomBitmap() + id2 := newRandomBitmap() + id3 := newRandomBitmap() + + seedIP := "127.0.0.1:21216" + + dht, err := New(&Config{Address: seedIP, NodeID: id1.Hex()}) + if err != nil { + t.Fatal(err) + } + go dht.Start() + + time.Sleep(1 * time.Second) + + dht2, err := New(&Config{Address: "127.0.0.1:21217", NodeID: id2.Hex(), SeedNodes: []string{seedIP}}) + if err != nil { + t.Fatal(err) + } + go dht2.Start() + + time.Sleep(1 * time.Second) // give dhts a chance to connect + + dht3, err := New(&Config{Address: "127.0.0.1:21218", NodeID: id3.Hex(), SeedNodes: []string{seedIP}}) + if err != nil { + t.Fatal(err) + } + go dht3.Start() + + time.Sleep(1 * time.Second) // give dhts a chance to connect + + spew.Dump(dht3.FindNodes(id2)) +} diff --git a/dht/message.go b/dht/message.go index f826e0d..facabda 100644 --- a/dht/message.go +++ b/dht/message.go @@ -183,7 +183,7 @@ func (r Response) ArgsDebug() string { str += "for " + hex.EncodeToString([]byte(r.FindValueKey))[:8] + " " } for _, c := range r.FindNodeData { - str += c.Addr().String() + ":" + c.id.Hex()[:8] + ", " + str += c.Addr().String() + ":" + c.id.HexShort() + ", " } return str[:len(str)-2] // chomp off last ", " } @@ -229,7 +229,22 @@ func (r *Response) UnmarshalBencode(b []byte) error { err = bencode.DecodeBytes(raw.Data, &r.Data) if err != nil { - err = bencode.DecodeBytes(raw.Data, r.FindNodeData) + var rawData map[string]bencode.RawMessage + err = bencode.DecodeBytes(raw.Data, &rawData) + if err != nil { + return err + } + + var rawContacts bencode.RawMessage + var ok bool + if rawContacts, ok = rawData["contacts"]; !ok { + for k, v := range rawData { + r.FindValueKey = k + rawContacts = v + break + } + } + err = bencode.DecodeBytes(rawContacts, &r.FindNodeData) if err != nil { return err } diff --git a/dht/routing_table.go b/dht/routing_table.go index a433ac2..3d914c7 100644 --- a/dht/routing_table.go +++ b/dht/routing_table.go @@ -94,11 +94,11 @@ func (n *Node) UnmarshalBencode(b []byte) error { } type SortedNode struct { - node *Node + node Node xorDistanceToTarget bitmap } -type byXorDistance []*SortedNode +type byXorDistance []SortedNode func (a byXorDistance) Len() int { return len(a) } func (a byXorDistance) Swap(i, j int) { a[i], a[j] = a[j], a[i] } @@ -128,9 +128,18 @@ func (rt *RoutingTable) BucketInfo() string { bucketInfo := []string{} for i, b := range rt.buckets { - count := countInList(b) + count := 0 + ids := "" + for curr := b.Front(); curr != nil; curr = curr.Next() { + count++ + if ids != "" { + ids += ", " + } + ids += curr.Value.(Node).id.HexShort() + } + if count > 0 { - bucketInfo = append(bucketInfo, fmt.Sprintf("Bucket %d: %d", i, count)) + bucketInfo = append(bucketInfo, fmt.Sprintf("Bucket %d: (%d) %s", i, count, ids)) } } if len(bucketInfo) == 0 { @@ -139,12 +148,12 @@ func (rt *RoutingTable) BucketInfo() string { return strings.Join(bucketInfo, "\n") } -func (rt *RoutingTable) Update(node *Node) { +func (rt *RoutingTable) Update(node Node) { rt.lock.Lock() defer rt.lock.Unlock() bucketNum := bucketFor(rt.node.id, node.id) bucket := rt.buckets[bucketNum] - element := findInList(bucket, rt.node.id) + element := findInList(bucket, node.id) if element == nil { if bucket.Len() >= bucketSize { // TODO: Ping front node first. Only remove if it does not respond @@ -167,13 +176,19 @@ func (rt *RoutingTable) RemoveByID(id bitmap) { } } -func (rt *RoutingTable) FindClosest(target bitmap, limit int) []*Node { +func (rt *RoutingTable) GetClosest(target bitmap, limit int) []Node { rt.lock.RLock() defer rt.lock.RUnlock() - var toSort []*SortedNode + var toSort []SortedNode + var bucketNum int + + if rt.node.id.Equals(target) { + bucketNum = 0 + } else { + bucketNum = bucketFor(rt.node.id, target) + } - bucketNum := bucketFor(rt.node.id, target) bucket := rt.buckets[bucketNum] toSort = appendNodes(toSort, bucket.Front(), target) @@ -190,7 +205,7 @@ func (rt *RoutingTable) FindClosest(target bitmap, limit int) []*Node { sort.Sort(byXorDistance(toSort)) - var nodes []*Node + var nodes []Node for _, c := range toSort { nodes = append(nodes, c.node) if len(nodes) >= limit { @@ -203,25 +218,17 @@ func (rt *RoutingTable) FindClosest(target bitmap, limit int) []*Node { func findInList(bucket *list.List, value bitmap) *list.Element { for curr := bucket.Front(); curr != nil; curr = curr.Next() { - if curr.Value.(*Node).id.Equals(value) { + if curr.Value.(Node).id.Equals(value) { return curr } } return nil } -func countInList(bucket *list.List) int { - count := 0 - for curr := bucket.Front(); curr != nil; curr = curr.Next() { - count++ - } - return count -} - -func appendNodes(nodes []*SortedNode, start *list.Element, target bitmap) []*SortedNode { +func appendNodes(nodes []SortedNode, start *list.Element, target bitmap) []SortedNode { for curr := start; curr != nil; curr = curr.Next() { - node := curr.Value.(*Node) - nodes = append(nodes, &SortedNode{node, node.id.Xor(target)}) + node := curr.Value.(Node) + nodes = append(nodes, SortedNode{node, node.id.Xor(target)}) } return nodes } @@ -232,3 +239,17 @@ func bucketFor(id bitmap, target bitmap) int { } return numBuckets - 1 - target.Xor(id).PrefixLen() } + +func sortNodesInPlace(nodes []Node, target bitmap) { + toSort := make([]SortedNode, len(nodes)) + + for i, n := range nodes { + toSort[i] = SortedNode{n, n.id.Xor(target)} + } + + sort.Sort(byXorDistance(toSort)) + + for i, c := range toSort { + nodes[i] = c.node + } +} diff --git a/dht/routing_table_test.go b/dht/routing_table_test.go index 1d33016..c66b46e 100644 --- a/dht/routing_table_test.go +++ b/dht/routing_table_test.go @@ -6,15 +6,41 @@ import ( "testing" ) +func TestRoutingTable_bucketFor(t *testing.T) { + target := newBitmapFromHex("000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000") + var tests = []struct { + id bitmap + target bitmap + expected int + }{ + {newBitmapFromHex("000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000001"), target, 0}, + {newBitmapFromHex("000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000002"), target, 1}, + {newBitmapFromHex("000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000003"), target, 1}, + {newBitmapFromHex("000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000004"), target, 2}, + {newBitmapFromHex("000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000005"), target, 2}, + {newBitmapFromHex("00000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000f"), target, 3}, + {newBitmapFromHex("000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000010"), target, 4}, + {newBitmapFromHex("F00000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000"), target, 383}, + {newBitmapFromHex("F0000000000000000000000000000000F0000000000000000000000000F0000000000000000000000000000000000000"), target, 383}, + } + + for _, tt := range tests { + bucket := bucketFor(tt.id, tt.target) + if bucket != tt.expected { + t.Errorf("bucketFor(%s, %s) => %d, want %d", tt.id.Hex(), tt.target.Hex(), bucket, tt.expected) + } + } +} + func TestRoutingTable(t *testing.T) { n1 := newBitmapFromHex("FFFFFFFF0000000000000000000000000000000000000000000000000000000000000000000000000000000000000000") n2 := newBitmapFromHex("FFFFFFF00000000000000000000000000000000000000000000000000000000000000000000000000000000000000000") n3 := newBitmapFromHex("111111110000000000000000000000000000000000000000000000000000000000000000000000000000000000000000") rt := newRoutingTable(&Node{n1, net.ParseIP("127.0.0.1"), 8000}) - rt.Update(&Node{n2, net.ParseIP("127.0.0.1"), 8001}) - rt.Update(&Node{n3, net.ParseIP("127.0.0.1"), 8002}) + rt.Update(Node{n2, net.ParseIP("127.0.0.1"), 8001}) + rt.Update(Node{n3, net.ParseIP("127.0.0.1"), 8002}) - contacts := rt.FindClosest(newBitmapFromHex("222222220000000000000000000000000000000000000000000000000000000000000000000000000000000000000000"), 1) + contacts := rt.GetClosest(newBitmapFromHex("222222220000000000000000000000000000000000000000000000000000000000000000000000000000000000000000"), 1) if len(contacts) != 1 { t.Fail() return @@ -23,7 +49,7 @@ func TestRoutingTable(t *testing.T) { t.Error(contacts[0]) } - contacts = rt.FindClosest(n2, 10) + contacts = rt.GetClosest(n2, 10) if len(contacts) != 2 { t.Error(len(contacts)) return diff --git a/dht/rpc.go b/dht/rpc.go index 35db813..ac53fd9 100644 --- a/dht/rpc.go +++ b/dht/rpc.go @@ -25,12 +25,13 @@ func newMessageID() string { // handlePacke handles packets received from udp. func handlePacket(dht *DHT, pkt packet) { - //log.Infof("Received message from %s:%s : %s\n", pkt.raddr.IP.String(), strconv.Itoa(pkt.raddr.Port), hex.EncodeToString(pkt.data)) + //log.Infof("[%s] Received message from %s:%s : %s\n", dht.node.id.HexShort(), pkt.raddr.IP.String(), strconv.Itoa(pkt.raddr.Port), hex.EncodeToString(pkt.data)) var data map[string]interface{} err := bencode.DecodeBytes(pkt.data, &data) if err != nil { - log.Errorf("error decoding data: %s\n%s", err, pkt.data) + log.Errorf("error decoding data: %s", err) + log.Errorf(hex.EncodeToString(pkt.data)) return } @@ -48,16 +49,17 @@ func handlePacket(dht *DHT, pkt packet) { log.Errorln(err) return } - log.Debugf("[%s] query %s: received request from %s: %s(%s)", dht.node.id.Hex()[:8], hex.EncodeToString([]byte(request.ID))[:8], hex.EncodeToString([]byte(request.NodeID))[:8], request.Method, argsToString(request.Args)) + log.Debugf("[%s] query %s: received request from %s: %s(%s)", dht.node.id.HexShort(), hex.EncodeToString([]byte(request.ID))[:8], hex.EncodeToString([]byte(request.NodeID))[:8], request.Method, argsToString(request.Args)) handleRequest(dht, pkt.raddr, request) case responseType: response := Response{} err = bencode.DecodeBytes(pkt.data, &response) if err != nil { + log.Errorln(err) return } - log.Debugf("[%s] query %s: received response from %s: %s", dht.node.id.Hex()[:8], hex.EncodeToString([]byte(response.ID))[:8], hex.EncodeToString([]byte(response.NodeID))[:8], response.Data) + log.Debugf("[%s] query %s: received response from %s: %s", dht.node.id.HexShort(), hex.EncodeToString([]byte(response.ID))[:8], hex.EncodeToString([]byte(response.NodeID))[:8], response.ArgsDebug()) handleResponse(dht, pkt.raddr, response) case errorType: @@ -67,7 +69,7 @@ func handlePacket(dht *DHT, pkt packet) { ExceptionType: data[headerPayloadField].(string), Response: getArgs(data[headerArgsField]), } - log.Debugf("[%s] query %s: received error from %s: %s", dht.node.id.Hex()[:8], hex.EncodeToString([]byte(e.ID))[:8], hex.EncodeToString([]byte(e.NodeID))[:8], e.ExceptionType) + log.Debugf("[%s] query %s: received error from %s: %s", dht.node.id.HexShort(), hex.EncodeToString([]byte(e.ID))[:8], hex.EncodeToString([]byte(e.NodeID))[:8], e.ExceptionType) handleError(dht, pkt.raddr, e) default: @@ -130,17 +132,17 @@ func handleRequest(dht *DHT, addr *net.UDPAddr, request Request) { return } - node := &Node{id: newBitmapFromString(request.NodeID), ip: addr.IP, port: addr.Port} + node := Node{id: newBitmapFromString(request.NodeID), ip: addr.IP, port: addr.Port} dht.rt.Update(node) } func doFindNodes(dht *DHT, addr *net.UDPAddr, request Request) { nodeID := newBitmapFromString(request.Args[0]) - closestNodes := dht.rt.FindClosest(nodeID, bucketSize) + closestNodes := dht.rt.GetClosest(nodeID, bucketSize) if len(closestNodes) > 0 { response := Response{ID: request.ID, NodeID: dht.node.id.RawString(), FindNodeData: make([]Node, len(closestNodes))} for i, n := range closestNodes { - response.FindNodeData[i] = *n + response.FindNodeData[i] = n } send(dht, addr, response) } else { @@ -155,25 +157,25 @@ func handleResponse(dht *DHT, addr *net.UDPAddr, response Response) { tx.res <- &response } - node := &Node{id: newBitmapFromString(response.NodeID), ip: addr.IP, port: addr.Port} + node := Node{id: newBitmapFromString(response.NodeID), ip: addr.IP, port: addr.Port} dht.rt.Update(node) } // handleError handles errors received from udp. func handleError(dht *DHT, addr *net.UDPAddr, e Error) { spew.Dump(e) - node := &Node{id: newBitmapFromString(e.NodeID), ip: addr.IP, port: addr.Port} + node := Node{id: newBitmapFromString(e.NodeID), ip: addr.IP, port: addr.Port} dht.rt.Update(node) } -// send sends data to the udp. +// send sends data to a udp address func send(dht *DHT, addr *net.UDPAddr, data Message) error { if req, ok := data.(Request); ok { - log.Debugf("[%s] query %s: sending request to %s : %s(%s)", dht.node.id.Hex()[:8], hex.EncodeToString([]byte(req.ID))[:8], addr.String(), req.Method, argsToString(req.Args)) + log.Debugf("[%s] query %s: sending request to %s : %s(%s)", dht.node.id.HexShort(), hex.EncodeToString([]byte(req.ID))[:8], addr.String(), req.Method, argsToString(req.Args)) } else if res, ok := data.(Response); ok { - log.Debugf("[%s] query %s: sending response to %s : %s", dht.node.id.Hex()[:8], hex.EncodeToString([]byte(res.ID))[:8], addr.String(), res.ArgsDebug()) + log.Debugf("[%s] query %s: sending response to %s : %s", dht.node.id.HexShort(), hex.EncodeToString([]byte(res.ID))[:8], addr.String(), res.ArgsDebug()) } else { - log.Debugf("[%s] %s", dht.node.id.Hex()[:8], spew.Sdump(data)) + log.Debugf("[%s] %s", dht.node.id.HexShort(), spew.Sdump(data)) } encoded, err := bencode.EncodeBytes(data) if err != nil { diff --git a/dht/rpc_test.go b/dht/rpc_test.go index fc511db..6742d87 100644 --- a/dht/rpc_test.go +++ b/dht/rpc_test.go @@ -53,6 +53,10 @@ func (t testUDPConn) WriteToUDP(b []byte, addr *net.UDPAddr) (int, error) { return len(b), nil } +func (t testUDPConn) SetReadDeadline(tm time.Time) error { + return nil +} + func (t testUDPConn) SetWriteDeadline(tm time.Time) error { return nil } @@ -69,8 +73,9 @@ func TestPing(t *testing.T) { t.Fatal(err) } dht.conn = conn - dht.listen() + go dht.listen() go dht.runHandler() + defer dht.Shutdown() messageID := newMessageID() @@ -164,8 +169,9 @@ func TestStore(t *testing.T) { } dht.conn = conn - dht.listen() + go dht.listen() go dht.runHandler() + defer dht.Shutdown() messageID := newMessageID() blobHashToStore := newRandomBitmap().RawString() @@ -257,15 +263,16 @@ func TestFindNode(t *testing.T) { t.Fatal(err) } dht.conn = conn - dht.listen() + go dht.listen() go dht.runHandler() + defer dht.Shutdown() nodesToInsert := 3 var nodes []Node for i := 0; i < nodesToInsert; i++ { n := Node{id: newRandomBitmap(), ip: net.ParseIP("127.0.0.1"), port: 10000 + i} nodes = append(nodes, n) - dht.rt.Update(&n) + dht.rt.Update(n) } messageID := newMessageID() @@ -334,15 +341,16 @@ func TestFindValueExisting(t *testing.T) { } dht.conn = conn - dht.listen() + go dht.listen() go dht.runHandler() + defer dht.Shutdown() nodesToInsert := 3 var nodes []Node for i := 0; i < nodesToInsert; i++ { n := Node{id: newRandomBitmap(), ip: net.ParseIP("127.0.0.1"), port: 10000 + i} nodes = append(nodes, n) - dht.rt.Update(&n) + dht.rt.Update(n) } //data, _ := hex.DecodeString("64313a30693065313a3132303a7de8e57d34e316abbb5a8a8da50dcd1ad4c80e0f313a3234383a7ce1b831dec8689e44f80f547d2dea171f6a625e1a4ff6c6165e645f953103dabeb068a622203f859c6c64658fd3aa3b313a33393a66696e6456616c7565313a346c34383aa47624b8e7ee1e54df0c45e2eb858feb0b705bd2a78d8b739be31ba188f4bd6f56b371c51fecc5280d5fd26ba4168e966565") @@ -418,15 +426,16 @@ func TestFindValueFallbackToFindNode(t *testing.T) { } dht.conn = conn - dht.listen() + go dht.listen() go dht.runHandler() + defer dht.Shutdown() nodesToInsert := 3 var nodes []Node for i := 0; i < nodesToInsert; i++ { n := Node{id: newRandomBitmap(), ip: net.ParseIP("127.0.0.1"), port: 10000 + i} nodes = append(nodes, n) - dht.rt.Update(&n) + dht.rt.Update(n) } messageID := newMessageID() diff --git a/dht/transaction_manager.go b/dht/transaction_manager.go index 3ed9444..122a1a2 100644 --- a/dht/transaction_manager.go +++ b/dht/transaction_manager.go @@ -1,6 +1,7 @@ package dht import ( + "context" "net" "sync" "time" @@ -10,7 +11,7 @@ import ( // query represents the query data included queried node and query-formed data. type transaction struct { - node *Node + node Node req *Request res chan *Response } @@ -60,37 +61,53 @@ func (tm *transactionManager) Find(id string, addr *net.UDPAddr) *transaction { return t } -func (tm *transactionManager) Send(node *Node, req *Request) *Response { +func (tm *transactionManager) SendAsync(ctx context.Context, node Node, req *Request) <-chan *Response { if node.id.Equals(tm.dht.node.id) { log.Error("sending query to self") return nil } - req.ID = newMessageID() - trans := &transaction{ - node: node, - req: req, - res: make(chan *Response), - } + ch := make(chan *Response, 1) - tm.insert(trans) - defer tm.delete(trans.req.ID) + go func() { + defer close(ch) - for i := 0; i < udpRetry; i++ { - if err := send(tm.dht, trans.node.Addr(), *trans.req); err != nil { - log.Error(err) - break + req.ID = newMessageID() + req.NodeID = tm.dht.node.id.RawString() + trans := &transaction{ + node: node, + req: req, + res: make(chan *Response), } - select { - case res := <-trans.res: - return res - case <-time.After(udpTimeout): - } - } + tm.insert(trans) + defer tm.delete(trans.req.ID) - tm.dht.rt.RemoveByID(trans.node.id) - return nil + for i := 0; i < udpRetry; i++ { + if err := send(tm.dht, trans.node.Addr(), *trans.req); err != nil { + log.Error(err) + continue // try again? return? + } + + select { + case res := <-trans.res: + ch <- res + return + case <-ctx.Done(): + return + case <-time.After(udpTimeout): + } + } + + // if request timed out each time + tm.dht.rt.RemoveByID(trans.node.id) + }() + + return ch +} + +func (tm *transactionManager) Send(node Node, req *Request) *Response { + return <-tm.SendAsync(context.Background(), node, req) } // Count returns the number of transactions in the manager