From 81759d8b5a0cd9b63ea3d25185bcc4db4957573a Mon Sep 17 00:00:00 2001 From: Alex Grintsvayg Date: Fri, 27 Apr 2018 20:16:12 -0400 Subject: [PATCH] move most dht code into Node --- dht/bootstrap.go | 7 + dht/dht.go | 219 +++++----------- dht/dht_test.go | 16 +- dht/message.go | 28 +-- dht/message_test.go | 8 +- dht/node.go | 403 ++++++++++++++++++++++++++++++ dht/node_finder.go | 229 +++++++++-------- dht/{rpc_test.go => node_test.go} | 81 +++--- dht/routing_table.go | 145 +++++------ dht/routing_table_test.go | 12 +- dht/rpc.go | 178 ------------- dht/store.go | 34 +-- dht/transaction_manager.go | 125 --------- 13 files changed, 758 insertions(+), 727 deletions(-) create mode 100644 dht/bootstrap.go create mode 100644 dht/node.go rename dht/{rpc_test.go => node_test.go} (92%) delete mode 100644 dht/rpc.go delete mode 100644 dht/transaction_manager.go diff --git a/dht/bootstrap.go b/dht/bootstrap.go new file mode 100644 index 0000000..db95f42 --- /dev/null +++ b/dht/bootstrap.go @@ -0,0 +1,7 @@ +package dht + +// DHT represents a DHT node. +type BootstrapNode struct { + // node + node *Node +} diff --git a/dht/dht.go b/dht/dht.go index 5d3cc6f..aef320d 100644 --- a/dht/dht.go +++ b/dht/dht.go @@ -11,9 +11,9 @@ import ( "github.com/lbryio/errors.go" "github.com/lbryio/lbry.go/stopOnce" + "github.com/spf13/cast" log "github.com/sirupsen/logrus" - "github.com/spf13/cast" ) func init() { @@ -42,12 +42,6 @@ const compactNodeInfoLength = nodeIDLength + 6 const tokenSecretRotationInterval = 5 * time.Minute // how often the token-generating secret is rotated -// packet represents the information receive from udp. -type packet struct { - data []byte - raddr *net.UDPAddr -} - // Config represents the configure of dht. type Config struct { // this node's address. format is `ip:port` @@ -72,33 +66,14 @@ func NewStandardConfig() *Config { } } -// UDPConn allows using a mocked connection to test sending/receiving data -type UDPConn interface { - ReadFromUDP([]byte) (int, *net.UDPAddr, error) - WriteToUDP([]byte, *net.UDPAddr) (int, error) - SetReadDeadline(time.Time) error - SetWriteDeadline(time.Time) error - Close() error -} - // DHT represents a DHT node. type DHT struct { // config conf *Config - // UDP connection for sending and receiving data - conn UDPConn - // the local dht node + // local contact + contact Contact + // node node *Node - // routing table - rt *routingTable - // channel of incoming packets - packets chan packet - // data store - store *peerStore - // transaction manager - tm *transactionManager - // token manager - tokens *tokenManager // stopper to shut down DHT stop *stopOnce.Stopper // wait group for all the things that need to be stopped when DHT shuts down @@ -113,107 +88,27 @@ func New(config *Config) (*DHT, error) { config = NewStandardConfig() } - var id Bitmap - if config.NodeID == "" { - id = RandomBitmapP() - } else { - id = BitmapFromHexP(config.NodeID) - } - - ip, port, err := net.SplitHostPort(config.Address) + contact, err := getContact(config.NodeID, config.Address) if err != nil { - return nil, errors.Err(err) - } else if ip == "" { - return nil, errors.Err("address does not contain an IP") - } else if port == "" { - return nil, errors.Err("address does not contain a port") + return nil, err } - portInt, err := cast.ToIntE(port) + node, err := NewNode(contact.id) if err != nil { - return nil, errors.Err(err) - } - - node := &Node{id: id, ip: net.ParseIP(ip), port: portInt} - if node.ip == nil { - return nil, errors.Err("invalid ip") + return nil, err } d := &DHT{ conf: config, + contact: contact, node: node, - rt: newRoutingTable(node), - packets: make(chan packet), - store: newPeerStore(), stop: stopOnce.New(), stopWG: &sync.WaitGroup{}, joined: make(chan struct{}), - tokens: &tokenManager{}, } - d.tm = newTransactionManager(d) - d.tokens.Start(tokenSecretRotationInterval) return d, nil } -// init initializes global variables. -func (dht *DHT) init() error { - listener, err := net.ListenPacket(network, dht.conf.Address) - if err != nil { - return errors.Err(err) - } - - dht.conn = listener.(*net.UDPConn) - - if dht.conf.PrintState > 0 { - go func() { - t := time.NewTicker(dht.conf.PrintState) - for { - dht.PrintState() - select { - case <-t.C: - case <-dht.stop.Chan(): - return - } - } - }() - } - - return nil -} - -// listen receives message from udp. -func (dht *DHT) listen() { - dht.stopWG.Add(1) - defer dht.stopWG.Done() - - buf := make([]byte, udpMaxMessageLength) - - for { - select { - case <-dht.stop.Chan(): - return - default: - } - - dht.conn.SetReadDeadline(time.Now().Add(200 * time.Millisecond)) // need this to periodically check shutdown chan - n, raddr, err := dht.conn.ReadFromUDP(buf) - if err != nil { - if e, ok := err.(net.Error); !ok || !e.Timeout() { - log.Errorf("udp read error: %v", err) - } - continue - } else if raddr == nil { - log.Errorf("udp read with no raddr") - continue - } - - data := make([]byte, n) - copy(data, buf[:n]) // slices use the same underlying array, so we need a new one for each packet - - dht.packets <- packet{data: data, raddr: raddr} - } -} - // join makes current node join the dht network. func (dht *DHT) join() { defer close(dht.joined) // if anyone's waiting for join to finish, they'll know its done @@ -243,34 +138,21 @@ func (dht *DHT) join() { } } -func (dht *DHT) runHandler() { - dht.stopWG.Add(1) - defer dht.stopWG.Done() - - var pkt packet - - for { - select { - case pkt = <-dht.packets: - handlePacket(dht, pkt) - case <-dht.stop.Chan(): - return - } - } -} - // Start starts the dht func (dht *DHT) Start() error { - err := dht.init() + listener, err := net.ListenPacket(network, dht.conf.Address) + if err != nil { + return errors.Err(err) + } + conn := listener.(*net.UDPConn) + + err = dht.node.Connect(conn) if err != nil { return err } - go dht.listen() - go dht.runHandler() - dht.join() - log.Debugf("[%s] DHT ready on %s (%d nodes found during join)", dht.node.id.HexShort(), dht.node.Addr().String(), dht.rt.Count()) + log.Debugf("[%s] DHT ready on %s (%d nodes found during join)", dht.node.id.HexShort(), dht.contact.Addr().String(), dht.node.rt.Count()) return nil } @@ -286,8 +168,7 @@ func (dht *DHT) Shutdown() { log.Debugf("[%s] DHT shutting down", dht.node.id.HexShort()) dht.stop.Stop() dht.stopWG.Wait() - dht.tokens.Stop() - dht.conn.Close() + dht.node.Shutdown() log.Debugf("[%s] DHT stopped", dht.node.id.HexShort()) } @@ -298,8 +179,8 @@ func (dht *DHT) Ping(addr string) error { return err } - tmpNode := Node{id: RandomBitmapP(), ip: raddr.IP, port: raddr.Port} - res := dht.tm.Send(tmpNode, Request{Method: pingMethod}) + tmpNode := Contact{id: RandomBitmapP(), ip: raddr.IP, port: raddr.Port} + res := dht.node.Send(tmpNode, Request{Method: pingMethod}) if res == nil { return errors.Err("no response from node %s", addr) } @@ -308,22 +189,22 @@ func (dht *DHT) Ping(addr string) error { } // Get returns the list of nodes that have the blob for the given hash -func (dht *DHT) Get(hash Bitmap) ([]Node, error) { - nf := newNodeFinder(dht, hash, true) +func (dht *DHT) Get(hash Bitmap) ([]Contact, error) { + nf := newContactFinder(dht.node, hash, true) res, err := nf.Find() if err != nil { return nil, err } if res.Found { - return res.Nodes, nil + return res.Contacts, nil } return nil, nil } // Announce announces to the DHT that this node has the blob for the given hash func (dht *DHT) Announce(hash Bitmap) error { - nf := newNodeFinder(dht, hash, false) + nf := newContactFinder(dht.node, hash, false) res, err := nf.Find() if err != nil { return err @@ -331,18 +212,18 @@ func (dht *DHT) Announce(hash Bitmap) error { // TODO: if this node is closer than farthest peer, store locally and pop farthest peer - for _, node := range res.Nodes { + for _, node := range res.Contacts { go dht.storeOnNode(hash, node) } return nil } -func (dht *DHT) storeOnNode(hash Bitmap, node Node) { +func (dht *DHT) storeOnNode(hash Bitmap, node Contact) { dht.stopWG.Add(1) defer dht.stopWG.Done() - resCh := dht.tm.SendAsync(context.Background(), node, Request{ + resCh := dht.node.SendAsync(context.Background(), node, Request{ Method: findValueMethod, Arg: &hash, }) @@ -358,30 +239,30 @@ func (dht *DHT) storeOnNode(hash Bitmap, node Node) { return // request timed out } - dht.tm.SendAsync(context.Background(), node, Request{ + dht.node.SendAsync(context.Background(), node, Request{ Method: storeMethod, StoreArgs: &storeArgs{ BlobHash: hash, Value: storeArgsValue{ Token: res.Token, - LbryID: dht.node.id, - Port: dht.node.port, + LbryID: dht.contact.id, + Port: dht.contact.port, }, }, }) } func (dht *DHT) PrintState() { - log.Printf("DHT node %s at %s", dht.node.String(), time.Now().Format(time.RFC822Z)) - log.Printf("Outstanding transactions: %d", dht.tm.Count()) - log.Printf("Stored hashes: %d", dht.store.CountStoredHashes()) + log.Printf("DHT node %s at %s", dht.contact.String(), time.Now().Format(time.RFC822Z)) + log.Printf("Outstanding transactions: %d", dht.node.CountActiveTransactions()) + log.Printf("Stored hashes: %d", dht.node.store.CountStoredHashes()) log.Printf("Buckets:") - for _, line := range strings.Split(dht.rt.BucketInfo(), "\n") { + for _, line := range strings.Split(dht.node.rt.BucketInfo(), "\n") { log.Println(line) } } -func printNodeList(list []Node) { +func printNodeList(list []Contact) { for i, n := range list { log.Printf("%d) %s", i, n.String()) } @@ -414,3 +295,33 @@ func MakeTestDHT(numNodes int) []*DHT { return dhts } + +func getContact(nodeID, addr string) (Contact, error) { + var c Contact + if nodeID == "" { + c.id = RandomBitmapP() + } else { + c.id = BitmapFromHexP(nodeID) + } + + ip, port, err := net.SplitHostPort(addr) + if err != nil { + return c, errors.Err(err) + } else if ip == "" { + return c, errors.Err("address does not contain an IP") + } else if port == "" { + return c, errors.Err("address does not contain a port") + } + + c.ip = net.ParseIP(ip) + if c.ip == nil { + return c, errors.Err("invalid ip") + } + + c.port, err = cast.ToIntE(port) + if err != nil { + return c, errors.Err(err) + } + + return c, nil +} diff --git a/dht/dht_test.go b/dht/dht_test.go index 5955de4..7892a58 100644 --- a/dht/dht_test.go +++ b/dht/dht_test.go @@ -20,12 +20,12 @@ func TestNodeFinder_FindNodes(t *testing.T) { } }() - nf := newNodeFinder(dhts[2], RandomBitmapP(), false) + nf := newContactFinder(dhts[2].node, RandomBitmapP(), false) res, err := nf.Find() if err != nil { t.Fatal(err) } - foundNodes, found := res.Nodes, res.Found + foundNodes, found := res.Contacts, res.Found if found { t.Fatal("something was found, but it should not have been") @@ -42,7 +42,7 @@ func TestNodeFinder_FindNodes(t *testing.T) { if n.id.Equals(dhts[0].node.id) { foundOne = true } - //if n.id.Equals(dhts[1].node.id) { + //if n.id.Equals(dhts[1].node.c.id) { // foundTwo = true //} } @@ -51,7 +51,7 @@ func TestNodeFinder_FindNodes(t *testing.T) { t.Errorf("did not find first node %s", dhts[0].node.id.Hex()) } //if !foundTwo { - // t.Errorf("did not find second node %s", dhts[1].node.id.Hex()) + // t.Errorf("did not find second node %s", dhts[1].node.c.id.Hex()) //} } @@ -64,15 +64,15 @@ func TestNodeFinder_FindValue(t *testing.T) { }() blobHashToFind := RandomBitmapP() - nodeToFind := Node{id: RandomBitmapP(), ip: net.IPv4(1, 2, 3, 4), port: 5678} - dhts[0].store.Upsert(blobHashToFind, nodeToFind) + nodeToFind := Contact{id: RandomBitmapP(), ip: net.IPv4(1, 2, 3, 4), port: 5678} + dhts[0].node.store.Upsert(blobHashToFind, nodeToFind) - nf := newNodeFinder(dhts[2], blobHashToFind, true) + nf := newContactFinder(dhts[2].node, blobHashToFind, true) res, err := nf.Find() if err != nil { t.Fatal(err) } - foundNodes, found := res.Nodes, res.Found + foundNodes, found := res.Contacts, res.Found if !found { t.Fatal("node was not found") diff --git a/dht/message.go b/dht/message.go index d693928..baefa34 100644 --- a/dht/message.go +++ b/dht/message.go @@ -223,7 +223,7 @@ type Response struct { ID messageID NodeID Bitmap Data string - FindNodeData []Node + Contacts []Contact FindValueKey string Token string } @@ -239,7 +239,7 @@ func (r Response) ArgsDebug() string { } str += "|" - for _, c := range r.FindNodeData { + for _, c := range r.Contacts { str += c.Addr().String() + ":" + c.id.HexShort() + "," } str = strings.TrimRight(str, ",") + "|" @@ -268,8 +268,8 @@ func (r Response) MarshalBencode() ([]byte, error) { } var contacts [][]byte - for _, n := range r.FindNodeData { - compact, err := n.MarshalCompact() + for _, c := range r.Contacts { + compact, err := c.MarshalCompact() if err != nil { return nil, err } @@ -282,12 +282,12 @@ func (r Response) MarshalBencode() ([]byte, error) { } else if r.Token != "" { // findValue failure falling back to findNode data[headerPayloadField] = map[string]interface{}{ - contactsField: r.FindNodeData, + contactsField: r.Contacts, tokenField: r.Token, } } else { // straight up findNode - data[headerPayloadField] = r.FindNodeData + data[headerPayloadField] = r.Contacts } return bencode.EncodeBytes(data) @@ -314,7 +314,7 @@ func (r *Response) UnmarshalBencode(b []byte) error { } // maybe data is a list of nodes (response to findNode)? - err = bencode.DecodeBytes(raw.Data, &r.FindNodeData) + err = bencode.DecodeBytes(raw.Data, &r.Contacts) if err == nil { return nil } @@ -335,25 +335,25 @@ func (r *Response) UnmarshalBencode(b []byte) error { } if contacts, ok := rawData[contactsField]; ok { - err = bencode.DecodeBytes(contacts, &r.FindNodeData) + err = bencode.DecodeBytes(contacts, &r.Contacts) if err != nil { return err } } else { for k, v := range rawData { r.FindValueKey = k - var compactNodes [][]byte - err = bencode.DecodeBytes(v, &compactNodes) + var compactContacts [][]byte + err = bencode.DecodeBytes(v, &compactContacts) if err != nil { return err } - for _, compact := range compactNodes { - var uncompactedNode Node - err = uncompactedNode.UnmarshalCompact(compact) + for _, compact := range compactContacts { + var c Contact + err = c.UnmarshalCompact(compact) if err != nil { return err } - r.FindNodeData = append(r.FindNodeData, uncompactedNode) + r.Contacts = append(r.Contacts, c) } break } diff --git a/dht/message_test.go b/dht/message_test.go index 863d7b5..5fa9d96 100644 --- a/dht/message_test.go +++ b/dht/message_test.go @@ -77,7 +77,7 @@ func TestBencodeFindNodesResponse(t *testing.T) { res := Response{ ID: newMessageID(), NodeID: RandomBitmapP(), - FindNodeData: []Node{ + Contacts: []Contact{ {id: RandomBitmapP(), ip: net.IPv4(1, 2, 3, 4).To4(), port: 5678}, {id: RandomBitmapP(), ip: net.IPv4(4, 3, 2, 1).To4(), port: 8765}, }, @@ -103,7 +103,7 @@ func TestBencodeFindValueResponse(t *testing.T) { NodeID: RandomBitmapP(), FindValueKey: RandomBitmapP().RawString(), Token: "arst", - FindNodeData: []Node{ + Contacts: []Contact{ {id: RandomBitmapP(), ip: net.IPv4(1, 2, 3, 4).To4(), port: 5678}, }, } @@ -182,7 +182,7 @@ func compareResponses(t *testing.T, res, res2 Response) { if res.Token != res2.Token { t.Errorf("expected Token %s, got %s", res.Token, res2.Token) } - if !reflect.DeepEqual(res.FindNodeData, res2.FindNodeData) { - t.Errorf("expected FindNodeData %s, got %s", spew.Sdump(res.FindNodeData), spew.Sdump(res2.FindNodeData)) + if !reflect.DeepEqual(res.Contacts, res2.Contacts) { + t.Errorf("expected FindNodeData %s, got %s", spew.Sdump(res.Contacts), spew.Sdump(res2.Contacts)) } } diff --git a/dht/node.go b/dht/node.go new file mode 100644 index 0000000..0fd4bc5 --- /dev/null +++ b/dht/node.go @@ -0,0 +1,403 @@ +package dht + +import ( + "context" + "encoding/hex" + "net" + "strings" + "sync" + "time" + + "github.com/lbryio/errors.go" + "github.com/lbryio/lbry.go/stopOnce" + "github.com/lbryio/lbry.go/util" + + "github.com/davecgh/go-spew/spew" + "github.com/lyoshenka/bencode" + log "github.com/sirupsen/logrus" +) + +// packet represents the information receive from udp. +type packet struct { + data []byte + raddr *net.UDPAddr +} + +// UDPConn allows using a mocked connection to test sending/receiving data +type UDPConn interface { + ReadFromUDP([]byte) (int, *net.UDPAddr, error) + WriteToUDP([]byte, *net.UDPAddr) (int, error) + SetReadDeadline(time.Time) error + SetWriteDeadline(time.Time) error + Close() error +} + +type Node struct { + // TODO: replace Contact with id. ip and port aren't used except when connecting + id Bitmap + // UDP connection for sending and receiving data + conn UDPConn + // token manager + tokens *tokenManager + + txLock *sync.RWMutex + transactions map[messageID]*transaction + + // routing table + rt *routingTable + // data store + store *peerStore + + stop *stopOnce.Stopper + stopWG *sync.WaitGroup +} + +// New returns a Node pointer. +func NewNode(id Bitmap) (*Node, error) { + n := &Node{ + id: id, + rt: newRoutingTable(id), + store: newPeerStore(), + + txLock: &sync.RWMutex{}, + transactions: make(map[messageID]*transaction), + + stop: stopOnce.New(), + stopWG: &sync.WaitGroup{}, + tokens: &tokenManager{}, + } + + n.tokens.Start(tokenSecretRotationInterval) + return n, nil +} + +func (n *Node) Connect(conn UDPConn) error { + n.conn = conn + + //if dht.conf.PrintState > 0 { + // go func() { + // t := time.NewTicker(dht.conf.PrintState) + // for { + // dht.PrintState() + // select { + // case <-t.C: + // case <-dht.stop.Chan(): + // return + // } + // } + // }() + //} + + packets := make(chan packet) + + go func() { + n.stopWG.Add(1) + defer n.stopWG.Done() + + buf := make([]byte, udpMaxMessageLength) + + for { + select { + case <-n.stop.Chan(): + return + default: + } + + n.conn.SetReadDeadline(time.Now().Add(200 * time.Millisecond)) // need this to periodically check shutdown chan + n, raddr, err := n.conn.ReadFromUDP(buf) + if err != nil { + if e, ok := err.(net.Error); !ok || !e.Timeout() { + log.Errorf("udp read error: %v", err) + } + continue + } else if raddr == nil { + log.Errorf("udp read with no raddr") + continue + } + + data := make([]byte, n) + copy(data, buf[:n]) // slices use the same underlying array, so we need a new one for each packet + + packets <- packet{data: data, raddr: raddr} + } + }() + + go func() { + n.stopWG.Add(1) + defer n.stopWG.Done() + + var pkt packet + + for { + select { + case pkt = <-packets: + n.handlePacket(pkt) + case <-n.stop.Chan(): + return + } + } + }() + + return nil +} + +// Shutdown shuts down the node +func (n *Node) Shutdown() { + log.Debugf("[%s] node shutting down", n.id.HexShort()) + n.stop.Stop() + n.stopWG.Wait() + n.tokens.Stop() + n.conn.Close() + log.Debugf("[%s] node stopped", n.id.HexShort()) +} + +// handlePacket handles packets received from udp. +func (n *Node) handlePacket(pkt packet) { + //log.Debugf("[%s] Received message from %s (%d bytes) %s", n.id.HexShort(), pkt.raddr.String(), len(pkt.data), hex.EncodeToString(pkt.data)) + + if !util.InSlice(string(pkt.data[0:5]), []string{"d1:0i", "di0ei"}) { + log.Errorf("[%s] data is not a well-formatted dict: (%d bytes) %s", n.id.HexShort(), len(pkt.data), hex.EncodeToString(pkt.data)) + return + } + + // TODO: test this stuff more thoroughly + + // the following is a bit of a hack, but it lets us avoid decoding every message twice + // it depends on the data being a dict with 0 as the first key (so it starts with "d1:0i") and the message type as the first value + + switch pkt.data[5] { + case '0' + requestType: + request := Request{} + err := bencode.DecodeBytes(pkt.data, &request) + if err != nil { + log.Errorf("[%s] error decoding request from %s: %s: (%d bytes) %s", n.id.HexShort(), pkt.raddr.String(), err.Error(), len(pkt.data), hex.EncodeToString(pkt.data)) + return + } + log.Debugf("[%s] query %s: received request from %s: %s(%s)", n.id.HexShort(), request.ID.HexShort(), request.NodeID.HexShort(), request.Method, request.ArgsDebug()) + n.handleRequest(pkt.raddr, request) + + case '0' + responseType: + response := Response{} + err := bencode.DecodeBytes(pkt.data, &response) + if err != nil { + log.Errorf("[%s] error decoding response from %s: %s: (%d bytes) %s", n.id.HexShort(), pkt.raddr.String(), err.Error(), len(pkt.data), hex.EncodeToString(pkt.data)) + return + } + log.Debugf("[%s] query %s: received response from %s: %s", n.id.HexShort(), response.ID.HexShort(), response.NodeID.HexShort(), response.ArgsDebug()) + n.handleResponse(pkt.raddr, response) + + case '0' + errorType: + e := Error{} + err := bencode.DecodeBytes(pkt.data, &e) + if err != nil { + log.Errorf("[%s] error decoding error from %s: %s: (%d bytes) %s", n.id.HexShort(), pkt.raddr.String(), err.Error(), len(pkt.data), hex.EncodeToString(pkt.data)) + return + } + log.Debugf("[%s] query %s: received error from %s: %s", n.id.HexShort(), e.ID.HexShort(), e.NodeID.HexShort(), e.ExceptionType) + n.handleError(pkt.raddr, e) + + default: + log.Errorf("[%s] invalid message type: %s", n.id.HexShort(), pkt.data[5]) + return + } +} + +// handleRequest handles the requests received from udp. +func (n *Node) handleRequest(addr *net.UDPAddr, request Request) { + if request.NodeID.Equals(n.id) { + log.Warn("ignoring self-request") + return + } + + switch request.Method { + default: + // n.send(addr, makeError(t, protocolError, "invalid q")) + log.Errorln("invalid request method") + return + case pingMethod: + n.send(addr, Response{ID: request.ID, NodeID: n.id, Data: pingSuccessResponse}) + case storeMethod: + // TODO: we should be sending the IP in the request, not just using the sender's IP + // TODO: should we be using StoreArgs.NodeID or StoreArgs.Value.LbryID ??? + if n.tokens.Verify(request.StoreArgs.Value.Token, request.NodeID, addr) { + n.store.Upsert(request.StoreArgs.BlobHash, Contact{id: request.StoreArgs.NodeID, ip: addr.IP, port: request.StoreArgs.Value.Port}) + n.send(addr, Response{ID: request.ID, NodeID: n.id, Data: storeSuccessResponse}) + } else { + n.send(addr, Error{ID: request.ID, NodeID: n.id, ExceptionType: "invalid-token"}) + } + case findNodeMethod: + if request.Arg == nil { + log.Errorln("request is missing arg") + return + } + n.send(addr, Response{ + ID: request.ID, + NodeID: n.id, + Contacts: n.rt.GetClosest(*request.Arg, bucketSize), + }) + + case findValueMethod: + if request.Arg == nil { + log.Errorln("request is missing arg") + return + } + + res := Response{ + ID: request.ID, + NodeID: n.id, + Token: n.tokens.Get(request.NodeID, addr), + } + + if contacts := n.store.Get(*request.Arg); len(contacts) > 0 { + res.FindValueKey = request.Arg.RawString() + res.Contacts = contacts + } else { + res.Contacts = n.rt.GetClosest(*request.Arg, bucketSize) + } + + n.send(addr, res) + } + + // nodes that send us requests should not be inserted, only refreshed. + // the routing table must only contain "good" nodes, which are nodes that reply to our requests + // if a node is already good (aka in the table), its fine to refresh it + // http://www.bittorrent.org/beps/bep_0005.html#routing-table + n.rt.UpdateIfExists(Contact{id: request.NodeID, ip: addr.IP, port: addr.Port}) +} + +// handleResponse handles responses received from udp. +func (n *Node) handleResponse(addr *net.UDPAddr, response Response) { + tx := n.txFind(response.ID, addr) + if tx != nil { + tx.res <- response + } + + n.rt.Update(Contact{id: response.NodeID, ip: addr.IP, port: addr.Port}) +} + +// handleError handles errors received from udp. +func (n *Node) handleError(addr *net.UDPAddr, e Error) { + spew.Dump(e) + n.rt.UpdateIfExists(Contact{id: e.NodeID, ip: addr.IP, port: addr.Port}) +} + +// send sends data to a udp address +func (n *Node) send(addr *net.UDPAddr, data Message) error { + encoded, err := bencode.EncodeBytes(data) + if err != nil { + return errors.Err(err) + } + + if req, ok := data.(Request); ok { + log.Debugf("[%s] query %s: sending request to %s (%d bytes) %s(%s)", + n.id.HexShort(), req.ID.HexShort(), addr.String(), len(encoded), req.Method, req.ArgsDebug()) + } else if res, ok := data.(Response); ok { + log.Debugf("[%s] query %s: sending response to %s (%d bytes) %s", + n.id.HexShort(), res.ID.HexShort(), addr.String(), len(encoded), res.ArgsDebug()) + } else { + log.Debugf("[%s] (%d bytes) %s", n.id.HexShort(), len(encoded), spew.Sdump(data)) + } + + n.conn.SetWriteDeadline(time.Now().Add(time.Second * 15)) + + _, err = n.conn.WriteToUDP(encoded, addr) + return errors.Err(err) +} + +// transaction represents a single query to the dht. it stores the queried contact, the request, and the response channel +type transaction struct { + contact Contact + req Request + res chan Response +} + +// insert adds a transaction to the manager. +func (n *Node) txInsert(tx *transaction) { + n.txLock.Lock() + defer n.txLock.Unlock() + n.transactions[tx.req.ID] = tx +} + +// delete removes a transaction from the manager. +func (n *Node) txDelete(id messageID) { + n.txLock.Lock() + defer n.txLock.Unlock() + delete(n.transactions, id) +} + +// Find finds a transaction for the given id. it optionally ensures that addr matches contact from transaction +func (n *Node) txFind(id messageID, addr *net.UDPAddr) *transaction { + n.txLock.RLock() + defer n.txLock.RUnlock() + + // TODO: also check that the response's nodeid matches the id you thought you sent to? + + t, ok := n.transactions[id] + if !ok || (addr != nil && t.contact.Addr().String() != addr.String()) { + return nil + } + + return t +} + +// SendAsync sends a transaction and returns a channel that will eventually contain the transaction response +// The response channel is closed when the transaction is completed or times out. +func (n *Node) SendAsync(ctx context.Context, contact Contact, req Request) <-chan *Response { + if contact.id.Equals(n.id) { + log.Error("sending query to self") + return nil + } + + ch := make(chan *Response, 1) + + go func() { + defer close(ch) + + req.ID = newMessageID() + req.NodeID = n.id + tx := &transaction{ + contact: contact, + req: req, + res: make(chan Response), + } + + n.txInsert(tx) + defer n.txDelete(tx.req.ID) + + for i := 0; i < udpRetry; i++ { + if err := n.send(contact.Addr(), tx.req); err != nil { + if !strings.Contains(err.Error(), "use of closed network connection") { // this only happens on localhost. real UDP has no connections + log.Error("send error: ", err) + } + continue // try again? return? + } + + select { + case res := <-tx.res: + ch <- &res + return + case <-ctx.Done(): + return + case <-time.After(udpTimeout): + } + } + + // if request timed out each time + n.rt.RemoveByID(tx.contact.id) + }() + + return ch +} + +// Send sends a transaction and blocks until the response is available. It returns a response, or nil +// if the transaction timed out. +func (n *Node) Send(contact Contact, req Request) *Response { + return <-n.SendAsync(context.Background(), contact, req) +} + +// Count returns the number of transactions in the manager +func (n *Node) CountActiveTransactions() int { + n.txLock.Lock() + defer n.txLock.Unlock() + return len(n.transactions) +} diff --git a/dht/node_finder.go b/dht/node_finder.go index 572bab5..fdf4f96 100644 --- a/dht/node_finder.go +++ b/dht/node_finder.go @@ -2,6 +2,7 @@ package dht import ( "context" + "sort" "sync" "time" @@ -11,21 +12,21 @@ import ( log "github.com/sirupsen/logrus" ) -type nodeFinder struct { +type contactFinder struct { findValue bool // true if we're using findValue target Bitmap - dht *DHT + node *Node done *stopOnce.Stopper findValueMutex *sync.Mutex - findValueResult []Node + findValueResult []Contact - activeNodesMutex *sync.Mutex - activeNodes []Node + activeContactsMutex *sync.Mutex + activeContacts []Contact shortlistMutex *sync.Mutex - shortlist []Node + shortlist []Contact shortlistAdded map[Bitmap]bool outstandingRequestsMutex *sync.RWMutex @@ -33,33 +34,33 @@ type nodeFinder struct { } type findNodeResponse struct { - Found bool - Nodes []Node + Found bool + Contacts []Contact } -func newNodeFinder(dht *DHT, target Bitmap, findValue bool) *nodeFinder { - return &nodeFinder{ - dht: dht, - target: target, - findValue: findValue, - findValueMutex: &sync.Mutex{}, - activeNodesMutex: &sync.Mutex{}, - shortlistMutex: &sync.Mutex{}, - shortlistAdded: make(map[Bitmap]bool), - done: stopOnce.New(), +func newContactFinder(node *Node, target Bitmap, findValue bool) *contactFinder { + return &contactFinder{ + node: node, + target: target, + findValue: findValue, + findValueMutex: &sync.Mutex{}, + activeContactsMutex: &sync.Mutex{}, + shortlistMutex: &sync.Mutex{}, + shortlistAdded: make(map[Bitmap]bool), + done: stopOnce.New(), outstandingRequestsMutex: &sync.RWMutex{}, } } -func (nf *nodeFinder) Find() (findNodeResponse, error) { - if nf.findValue { - log.Debugf("[%s] starting an iterative Find for the value %s", nf.dht.node.id.HexShort(), nf.target.HexShort()) +func (cf *contactFinder) Find() (findNodeResponse, error) { + if cf.findValue { + log.Debugf("[%s] starting an iterative Find for the value %s", cf.node.id.HexShort(), cf.target.HexShort()) } else { - log.Debugf("[%s] starting an iterative Find for nodes near %s", nf.dht.node.id.HexShort(), nf.target.HexShort()) + log.Debugf("[%s] starting an iterative Find for contacts near %s", cf.node.id.HexShort(), cf.target.HexShort()) } - nf.appendNewToShortlist(nf.dht.rt.GetClosest(nf.target, alpha)) - if len(nf.shortlist) == 0 { - return findNodeResponse{}, errors.Err("no nodes in routing table") + cf.appendNewToShortlist(cf.node.rt.GetClosest(cf.target, alpha)) + if len(cf.shortlist) == 0 { + return findNodeResponse{}, errors.Err("no contacts in routing table") } wg := &sync.WaitGroup{} @@ -68,163 +69,163 @@ func (nf *nodeFinder) Find() (findNodeResponse, error) { wg.Add(1) go func(i int) { defer wg.Done() - nf.iterationWorker(i + 1) + cf.iterationWorker(i + 1) }(i) } wg.Wait() - // TODO: what to do if we have less than K active nodes, shortlist is empty, but we - // TODO: have other nodes in our routing table whom we have not contacted. prolly contact them + // TODO: what to do if we have less than K active contacts, shortlist is empty, but we + // TODO: have other contacts in our routing table whom we have not contacted. prolly contact them result := findNodeResponse{} - if nf.findValue && len(nf.findValueResult) > 0 { + if cf.findValue && len(cf.findValueResult) > 0 { result.Found = true - result.Nodes = nf.findValueResult + result.Contacts = cf.findValueResult } else { - result.Nodes = nf.activeNodes - if len(result.Nodes) > bucketSize { - result.Nodes = result.Nodes[:bucketSize] + result.Contacts = cf.activeContacts + if len(result.Contacts) > bucketSize { + result.Contacts = result.Contacts[:bucketSize] } } return result, nil } -func (nf *nodeFinder) iterationWorker(num int) { - log.Debugf("[%s] starting worker %d", nf.dht.node.id.HexShort(), num) - defer func() { log.Debugf("[%s] stopping worker %d", nf.dht.node.id.HexShort(), num) }() +func (cf *contactFinder) iterationWorker(num int) { + log.Debugf("[%s] starting worker %d", cf.node.id.HexShort(), num) + defer func() { log.Debugf("[%s] stopping worker %d", cf.node.id.HexShort(), num) }() for { - maybeNode := nf.popFromShortlist() - if maybeNode == nil { + maybeContact := cf.popFromShortlist() + if maybeContact == nil { // TODO: block if there are pending requests out from other workers. there may be more shortlist values coming - log.Debugf("[%s] worker %d: no nodes in shortlist, waiting...", nf.dht.node.id.HexShort(), num) + log.Debugf("[%s] worker %d: no contacts in shortlist, waiting...", cf.node.id.HexShort(), num) time.Sleep(100 * time.Millisecond) } else { - node := *maybeNode + contact := *maybeContact - if node.id.Equals(nf.dht.node.id) { + if contact.id.Equals(cf.node.id) { continue // cannot contact self } - req := Request{Arg: &nf.target} - if nf.findValue { + req := Request{Arg: &cf.target} + if cf.findValue { req.Method = findValueMethod } else { req.Method = findNodeMethod } - log.Debugf("[%s] worker %d: contacting %s", nf.dht.node.id.HexShort(), num, node.id.HexShort()) + log.Debugf("[%s] worker %d: contacting %s", cf.node.id.HexShort(), num, contact.id.HexShort()) - nf.incrementOutstanding() + cf.incrementOutstanding() var res *Response ctx, cancel := context.WithCancel(context.Background()) - resCh := nf.dht.tm.SendAsync(ctx, node, req) + resCh := cf.node.SendAsync(ctx, contact, req) select { case res = <-resCh: - case <-nf.done.Chan(): - log.Debugf("[%s] worker %d: canceled", nf.dht.node.id.HexShort(), num) + case <-cf.done.Chan(): + log.Debugf("[%s] worker %d: canceled", cf.node.id.HexShort(), num) cancel() return } if res == nil { // nothing to do, response timed out - log.Debugf("[%s] worker %d: timed out waiting for %s", nf.dht.node.id.HexShort(), num, node.id.HexShort()) - } else if nf.findValue && res.FindValueKey != "" { - log.Debugf("[%s] worker %d: got value", nf.dht.node.id.HexShort(), num) - nf.findValueMutex.Lock() - nf.findValueResult = res.FindNodeData - nf.findValueMutex.Unlock() - nf.done.Stop() + log.Debugf("[%s] worker %d: timed out waiting for %s", cf.node.id.HexShort(), num, contact.id.HexShort()) + } else if cf.findValue && res.FindValueKey != "" { + log.Debugf("[%s] worker %d: got value", cf.node.id.HexShort(), num) + cf.findValueMutex.Lock() + cf.findValueResult = res.Contacts + cf.findValueMutex.Unlock() + cf.done.Stop() return } else { - log.Debugf("[%s] worker %d: got contacts", nf.dht.node.id.HexShort(), num) - nf.insertIntoActiveList(node) - nf.appendNewToShortlist(res.FindNodeData) + log.Debugf("[%s] worker %d: got contacts", cf.node.id.HexShort(), num) + cf.insertIntoActiveList(contact) + cf.appendNewToShortlist(res.Contacts) } - nf.decrementOutstanding() // this is all the way down here because we need to add to shortlist first + cf.decrementOutstanding() // this is all the way down here because we need to add to shortlist first } - if nf.isSearchFinished() { - log.Debugf("[%s] worker %d: search is finished", nf.dht.node.id.HexShort(), num) - nf.done.Stop() + if cf.isSearchFinished() { + log.Debugf("[%s] worker %d: search is finished", cf.node.id.HexShort(), num) + cf.done.Stop() return } } } -func (nf *nodeFinder) appendNewToShortlist(nodes []Node) { - nf.shortlistMutex.Lock() - defer nf.shortlistMutex.Unlock() +func (cf *contactFinder) appendNewToShortlist(contacts []Contact) { + cf.shortlistMutex.Lock() + defer cf.shortlistMutex.Unlock() - for _, n := range nodes { - if _, ok := nf.shortlistAdded[n.id]; !ok { - nf.shortlist = append(nf.shortlist, n) - nf.shortlistAdded[n.id] = true + for _, c := range contacts { + if _, ok := cf.shortlistAdded[c.id]; !ok { + cf.shortlist = append(cf.shortlist, c) + cf.shortlistAdded[c.id] = true } } - sortNodesInPlace(nf.shortlist, nf.target) + sortInPlace(cf.shortlist, cf.target) } -func (nf *nodeFinder) popFromShortlist() *Node { - nf.shortlistMutex.Lock() - defer nf.shortlistMutex.Unlock() +func (cf *contactFinder) popFromShortlist() *Contact { + cf.shortlistMutex.Lock() + defer cf.shortlistMutex.Unlock() - if len(nf.shortlist) == 0 { + if len(cf.shortlist) == 0 { return nil } - first := nf.shortlist[0] - nf.shortlist = nf.shortlist[1:] + first := cf.shortlist[0] + cf.shortlist = cf.shortlist[1:] return &first } -func (nf *nodeFinder) insertIntoActiveList(node Node) { - nf.activeNodesMutex.Lock() - defer nf.activeNodesMutex.Unlock() +func (cf *contactFinder) insertIntoActiveList(contact Contact) { + cf.activeContactsMutex.Lock() + defer cf.activeContactsMutex.Unlock() inserted := false - for i, n := range nf.activeNodes { - if node.id.Xor(nf.target).Less(n.id.Xor(nf.target)) { - nf.activeNodes = append(nf.activeNodes[:i], append([]Node{node}, nf.activeNodes[i:]...)...) + for i, n := range cf.activeContacts { + if contact.id.Xor(cf.target).Less(n.id.Xor(cf.target)) { + cf.activeContacts = append(cf.activeContacts[:i], append([]Contact{contact}, cf.activeContacts[i:]...)...) inserted = true break } } if !inserted { - nf.activeNodes = append(nf.activeNodes, node) + cf.activeContacts = append(cf.activeContacts, contact) } } -func (nf *nodeFinder) isSearchFinished() bool { - if nf.findValue && len(nf.findValueResult) > 0 { +func (cf *contactFinder) isSearchFinished() bool { + if cf.findValue && len(cf.findValueResult) > 0 { return true } select { - case <-nf.done.Chan(): + case <-cf.done.Chan(): return true default: } - if !nf.areRequestsOutstanding() { - nf.shortlistMutex.Lock() - defer nf.shortlistMutex.Unlock() + if !cf.areRequestsOutstanding() { + cf.shortlistMutex.Lock() + defer cf.shortlistMutex.Unlock() - if len(nf.shortlist) == 0 { + if len(cf.shortlist) == 0 { return true } - nf.activeNodesMutex.Lock() - defer nf.activeNodesMutex.Unlock() + cf.activeContactsMutex.Lock() + defer cf.activeContactsMutex.Unlock() - if len(nf.activeNodes) >= bucketSize && nf.activeNodes[bucketSize-1].id.Xor(nf.target).Less(nf.shortlist[0].id.Xor(nf.target)) { - // we have at least K active nodes, and we don't have any closer nodes yet to contact + if len(cf.activeContacts) >= bucketSize && cf.activeContacts[bucketSize-1].id.Xor(cf.target).Less(cf.shortlist[0].id.Xor(cf.target)) { + // we have at least K active contacts, and we don't have any closer contacts to ping return true } } @@ -232,20 +233,34 @@ func (nf *nodeFinder) isSearchFinished() bool { return false } -func (nf *nodeFinder) incrementOutstanding() { - nf.outstandingRequestsMutex.Lock() - defer nf.outstandingRequestsMutex.Unlock() - nf.outstandingRequests++ +func (cf *contactFinder) incrementOutstanding() { + cf.outstandingRequestsMutex.Lock() + defer cf.outstandingRequestsMutex.Unlock() + cf.outstandingRequests++ } -func (nf *nodeFinder) decrementOutstanding() { - nf.outstandingRequestsMutex.Lock() - defer nf.outstandingRequestsMutex.Unlock() - if nf.outstandingRequests > 0 { - nf.outstandingRequests-- +func (cf *contactFinder) decrementOutstanding() { + cf.outstandingRequestsMutex.Lock() + defer cf.outstandingRequestsMutex.Unlock() + if cf.outstandingRequests > 0 { + cf.outstandingRequests-- } } -func (nf *nodeFinder) areRequestsOutstanding() bool { - nf.outstandingRequestsMutex.RLock() - defer nf.outstandingRequestsMutex.RUnlock() - return nf.outstandingRequests > 0 +func (cf *contactFinder) areRequestsOutstanding() bool { + cf.outstandingRequestsMutex.RLock() + defer cf.outstandingRequestsMutex.RUnlock() + return cf.outstandingRequests > 0 +} + +func sortInPlace(contacts []Contact, target Bitmap) { + toSort := make([]sortedContact, len(contacts)) + + for i, n := range contacts { + toSort[i] = sortedContact{n, n.id.Xor(target)} + } + + sort.Sort(byXorDistance(toSort)) + + for i, c := range toSort { + contacts[i] = c.contact + } } diff --git a/dht/rpc_test.go b/dht/node_test.go similarity index 92% rename from dht/rpc_test.go rename to dht/node_test.go index 930a587..64e42a8 100644 --- a/dht/rpc_test.go +++ b/dht/node_test.go @@ -97,9 +97,11 @@ func TestPing(t *testing.T) { if err != nil { t.Fatal(err) } - dht.conn = conn - go dht.listen() - go dht.runHandler() + + err = dht.node.Connect(conn) + if err != nil { + t.Fatal(err) + } defer dht.Shutdown() messageID := newMessageID() @@ -193,9 +195,10 @@ func TestStore(t *testing.T) { t.Fatal(err) } - dht.conn = conn - go dht.listen() - go dht.runHandler() + err = dht.node.Connect(conn) + if err != nil { + t.Fatal(err) + } defer dht.Shutdown() messageID := newMessageID() @@ -208,7 +211,7 @@ func TestStore(t *testing.T) { StoreArgs: &storeArgs{ BlobHash: blobHashToStore, Value: storeArgsValue{ - Token: dht.tokens.Get(testNodeID, conn.addr), + Token: dht.node.tokens.Get(testNodeID, conn.addr), LbryID: testNodeID, Port: 9999, }, @@ -266,11 +269,11 @@ func TestStore(t *testing.T) { } } - if len(dht.store.hashes) != 1 { + if len(dht.node.store.hashes) != 1 { t.Error("dht store has wrong number of items") } - items := dht.store.Get(blobHashToStore) + items := dht.node.store.Get(blobHashToStore) if len(items) != 1 { t.Error("list created in store, but nothing in list") } @@ -289,17 +292,19 @@ func TestFindNode(t *testing.T) { if err != nil { t.Fatal(err) } - dht.conn = conn - go dht.listen() - go dht.runHandler() + + err = dht.node.Connect(conn) + if err != nil { + t.Fatal(err) + } defer dht.Shutdown() nodesToInsert := 3 - var nodes []Node + var nodes []Contact for i := 0; i < nodesToInsert; i++ { - n := Node{id: RandomBitmapP(), ip: net.ParseIP("127.0.0.1"), port: 10000 + i} + n := Contact{id: RandomBitmapP(), ip: net.ParseIP("127.0.0.1"), port: 10000 + i} nodes = append(nodes, n) - dht.rt.Update(n) + dht.node.rt.Update(n) } messageID := newMessageID() @@ -357,17 +362,18 @@ func TestFindValueExisting(t *testing.T) { t.Fatal(err) } - dht.conn = conn - go dht.listen() - go dht.runHandler() + err = dht.node.Connect(conn) + if err != nil { + t.Fatal(err) + } defer dht.Shutdown() nodesToInsert := 3 - var nodes []Node + var nodes []Contact for i := 0; i < nodesToInsert; i++ { - n := Node{id: RandomBitmapP(), ip: net.ParseIP("127.0.0.1"), port: 10000 + i} + n := Contact{id: RandomBitmapP(), ip: net.ParseIP("127.0.0.1"), port: 10000 + i} nodes = append(nodes, n) - dht.rt.Update(n) + dht.node.rt.Update(n) } //data, _ := hex.DecodeString("64313a30693065313a3132303a7de8e57d34e316abbb5a8a8da50dcd1ad4c80e0f313a3234383a7ce1b831dec8689e44f80f547d2dea171f6a625e1a4ff6c6165e645f953103dabeb068a622203f859c6c64658fd3aa3b313a33393a66696e6456616c7565313a346c34383aa47624b8e7ee1e54df0c45e2eb858feb0b705bd2a78d8b739be31ba188f4bd6f56b371c51fecc5280d5fd26ba4168e966565") @@ -375,10 +381,10 @@ func TestFindValueExisting(t *testing.T) { messageID := newMessageID() valueToFind := RandomBitmapP() - nodeToFind := Node{id: RandomBitmapP(), ip: net.ParseIP("1.2.3.4"), port: 1286} - dht.store.Upsert(valueToFind, nodeToFind) - dht.store.Upsert(valueToFind, nodeToFind) - dht.store.Upsert(valueToFind, nodeToFind) + nodeToFind := Contact{id: RandomBitmapP(), ip: net.ParseIP("1.2.3.4"), port: 1286} + dht.node.store.Upsert(valueToFind, nodeToFind) + dht.node.store.Upsert(valueToFind, nodeToFind) + dht.node.store.Upsert(valueToFind, nodeToFind) request := Request{ ID: messageID, @@ -428,7 +434,7 @@ func TestFindValueExisting(t *testing.T) { t.Fatal("search results are not a list") } - verifyCompactContacts(t, contacts, []Node{nodeToFind}) + verifyCompactContacts(t, contacts, []Contact{nodeToFind}) } func TestFindValueFallbackToFindNode(t *testing.T) { @@ -442,17 +448,18 @@ func TestFindValueFallbackToFindNode(t *testing.T) { t.Fatal(err) } - dht.conn = conn - go dht.listen() - go dht.runHandler() + err = dht.node.Connect(conn) + if err != nil { + t.Fatal(err) + } defer dht.Shutdown() nodesToInsert := 3 - var nodes []Node + var nodes []Contact for i := 0; i < nodesToInsert; i++ { - n := Node{id: RandomBitmapP(), ip: net.ParseIP("127.0.0.1"), port: 10000 + i} + n := Contact{id: RandomBitmapP(), ip: net.ParseIP("127.0.0.1"), port: 10000 + i} nodes = append(nodes, n) - dht.rt.Update(n) + dht.node.rt.Update(n) } messageID := newMessageID() @@ -557,7 +564,7 @@ func verifyResponse(t *testing.T, resp map[string]interface{}, id messageID, dht } } -func verifyContacts(t *testing.T, contacts []interface{}, nodes []Node) { +func verifyContacts(t *testing.T, contacts []interface{}, nodes []Contact) { if len(contacts) != len(nodes) { t.Errorf("got %d contacts; expected %d", len(contacts), len(nodes)) return @@ -577,7 +584,7 @@ func verifyContacts(t *testing.T, contacts []interface{}, nodes []Node) { return } - var currNode Node + var currNode Contact currNodeFound := false id, ok := contact[0].(string) @@ -618,7 +625,7 @@ func verifyContacts(t *testing.T, contacts []interface{}, nodes []Node) { } } -func verifyCompactContacts(t *testing.T, contacts []interface{}, nodes []Node) { +func verifyCompactContacts(t *testing.T, contacts []interface{}, nodes []Contact) { if len(contacts) != len(nodes) { t.Errorf("got %d contacts; expected %d", len(contacts), len(nodes)) return @@ -633,14 +640,14 @@ func verifyCompactContacts(t *testing.T, contacts []interface{}, nodes []Node) { return } - contact := Node{} + contact := Contact{} err := contact.UnmarshalCompact([]byte(compact)) if err != nil { t.Error(err) return } - var currNode Node + var currNode Contact currNodeFound := false if _, ok := foundNodes[contact.id.Hex()]; ok { diff --git a/dht/routing_table.go b/dht/routing_table.go index caae5c0..bf4beba 100644 --- a/dht/routing_table.go +++ b/dht/routing_table.go @@ -14,34 +14,33 @@ import ( "github.com/lyoshenka/bencode" ) -type Node struct { - id Bitmap - ip net.IP - port int - token string // this is set when the node is returned from a FindNode call +type Contact struct { + id Bitmap + ip net.IP + port int } -func (n Node) String() string { - return n.id.HexShort() + "@" + n.Addr().String() +func (c Contact) Addr() *net.UDPAddr { + return &net.UDPAddr{IP: c.ip, Port: c.port} } -func (n Node) Addr() *net.UDPAddr { - return &net.UDPAddr{IP: n.ip, Port: n.port} +func (c Contact) String() string { + return c.id.HexShort() + "@" + c.Addr().String() } -func (n Node) MarshalCompact() ([]byte, error) { - if n.ip.To4() == nil { +func (c Contact) MarshalCompact() ([]byte, error) { + if c.ip.To4() == nil { return nil, errors.Err("ip not set") } - if n.port < 0 || n.port > 65535 { + if c.port < 0 || c.port > 65535 { return nil, errors.Err("invalid port") } var buf bytes.Buffer - buf.Write(n.ip.To4()) - buf.WriteByte(byte(n.port >> 8)) - buf.WriteByte(byte(n.port)) - buf.Write(n.id[:]) + buf.Write(c.ip.To4()) + buf.WriteByte(byte(c.port >> 8)) + buf.WriteByte(byte(c.port)) + buf.Write(c.id[:]) if buf.Len() != compactNodeInfoLength { return nil, errors.Err("i dont know how this happened") @@ -50,21 +49,21 @@ func (n Node) MarshalCompact() ([]byte, error) { return buf.Bytes(), nil } -func (n *Node) UnmarshalCompact(b []byte) error { +func (c *Contact) UnmarshalCompact(b []byte) error { if len(b) != compactNodeInfoLength { return errors.Err("invalid compact length") } - n.ip = net.IPv4(b[0], b[1], b[2], b[3]).To4() - n.port = int(uint16(b[5]) | uint16(b[4])<<8) - n.id = BitmapFromBytesP(b[6:]) + c.ip = net.IPv4(b[0], b[1], b[2], b[3]).To4() + c.port = int(uint16(b[5]) | uint16(b[4])<<8) + c.id = BitmapFromBytesP(b[6:]) return nil } -func (n Node) MarshalBencode() ([]byte, error) { - return bencode.EncodeBytes([]interface{}{n.id, n.ip.String(), n.port}) +func (c Contact) MarshalBencode() ([]byte, error) { + return bencode.EncodeBytes([]interface{}{c.id, c.ip.String(), c.port}) } -func (n *Node) UnmarshalBencode(b []byte) error { +func (c *Contact) UnmarshalBencode(b []byte) error { var raw []bencode.RawMessage err := bencode.DecodeBytes(b, &raw) if err != nil { @@ -75,7 +74,7 @@ func (n *Node) UnmarshalBencode(b []byte) error { return errors.Err("contact must have 3 elements; got %d", len(raw)) } - err = bencode.DecodeBytes(raw[0], &n.id) + err = bencode.DecodeBytes(raw[0], &c.id) if err != nil { return err } @@ -85,12 +84,12 @@ func (n *Node) UnmarshalBencode(b []byte) error { if err != nil { return err } - n.ip = net.ParseIP(ipStr).To4() - if n.ip == nil { + c.ip = net.ParseIP(ipStr).To4() + if c.ip == nil { return errors.Err("invalid IP") } - err = bencode.DecodeBytes(raw[2], &n.port) + err = bencode.DecodeBytes(raw[2], &c.port) if err != nil { return err } @@ -98,12 +97,12 @@ func (n *Node) UnmarshalBencode(b []byte) error { return nil } -type sortedNode struct { - node Node +type sortedContact struct { + contact Contact xorDistanceToTarget Bitmap } -type byXorDistance []sortedNode +type byXorDistance []sortedContact func (a byXorDistance) Len() int { return len(a) } func (a byXorDistance) Swap(i, j int) { a[i], a[j] = a[j], a[i] } @@ -112,17 +111,17 @@ func (a byXorDistance) Less(i, j int) bool { } type routingTable struct { - node Node + id Bitmap buckets [numBuckets]*list.List lock *sync.RWMutex } -func newRoutingTable(node *Node) *routingTable { +func newRoutingTable(id Bitmap) *routingTable { var rt routingTable for i := range rt.buckets { rt.buckets[i] = list.New() } - rt.node = *node + rt.id = id rt.lock = &sync.RWMutex{} return &rt } @@ -131,7 +130,7 @@ func (rt *routingTable) BucketInfo() string { rt.lock.RLock() defer rt.lock.RUnlock() - bucketInfo := []string{} + var bucketInfo []string for i, b := range rt.buckets { contents := bucketContents(b) if contents != "" { @@ -152,7 +151,7 @@ func bucketContents(b *list.List) string { if ids != "" { ids += ", " } - ids += curr.Value.(Node).id.HexShort() + ids += curr.Value.(Contact).id.HexShort() } if count > 0 { @@ -162,31 +161,31 @@ func bucketContents(b *list.List) string { } } -// Update inserts or refreshes a node -func (rt *routingTable) Update(node Node) { +// Update inserts or refreshes a contact +func (rt *routingTable) Update(c Contact) { rt.lock.Lock() defer rt.lock.Unlock() - bucketNum := bucketFor(rt.node.id, node.id) + bucketNum := bucketFor(rt.id, c.id) bucket := rt.buckets[bucketNum] - element := findInList(bucket, node.id) + element := findInList(bucket, c.id) if element == nil { if bucket.Len() >= bucketSize { - // TODO: Ping front node first. Only remove if it does not respond + // TODO: Ping front contact first. Only remove if it does not respond bucket.Remove(bucket.Front()) } - bucket.PushBack(node) + bucket.PushBack(c) } else { bucket.MoveToBack(element) } } -// UpdateIfExists refreshes a node if its already in the routing table -func (rt *routingTable) UpdateIfExists(node Node) { +// UpdateIfExists refreshes a contact if its already in the routing table +func (rt *routingTable) UpdateIfExists(c Contact) { rt.lock.Lock() defer rt.lock.Unlock() - bucketNum := bucketFor(rt.node.id, node.id) + bucketNum := bucketFor(rt.id, c.id) bucket := rt.buckets[bucketNum] - element := findInList(bucket, node.id) + element := findInList(bucket, c.id) if element != nil { bucket.MoveToBack(element) } @@ -195,55 +194,55 @@ func (rt *routingTable) UpdateIfExists(node Node) { func (rt *routingTable) RemoveByID(id Bitmap) { rt.lock.Lock() defer rt.lock.Unlock() - bucketNum := bucketFor(rt.node.id, id) + bucketNum := bucketFor(rt.id, id) bucket := rt.buckets[bucketNum] - element := findInList(bucket, rt.node.id) + element := findInList(bucket, rt.id) if element != nil { bucket.Remove(element) } } -func (rt *routingTable) GetClosest(target Bitmap, limit int) []Node { +func (rt *routingTable) GetClosest(target Bitmap, limit int) []Contact { rt.lock.RLock() defer rt.lock.RUnlock() - var toSort []sortedNode + var toSort []sortedContact var bucketNum int - if rt.node.id.Equals(target) { + if rt.id.Equals(target) { bucketNum = 0 } else { - bucketNum = bucketFor(rt.node.id, target) + bucketNum = bucketFor(rt.id, target) } bucket := rt.buckets[bucketNum] - toSort = appendNodes(toSort, bucket.Front(), target) + toSort = appendContacts(toSort, bucket.Front(), target) for i := 1; (bucketNum-i >= 0 || bucketNum+i < numBuckets) && len(toSort) < limit; i++ { if bucketNum-i >= 0 { bucket = rt.buckets[bucketNum-i] - toSort = appendNodes(toSort, bucket.Front(), target) + toSort = appendContacts(toSort, bucket.Front(), target) } if bucketNum+i < numBuckets { bucket = rt.buckets[bucketNum+i] - toSort = appendNodes(toSort, bucket.Front(), target) + toSort = appendContacts(toSort, bucket.Front(), target) } } sort.Sort(byXorDistance(toSort)) - var nodes []Node - for _, c := range toSort { - nodes = append(nodes, c.node) - if len(nodes) >= limit { + var contacts []Contact + for _, sorted := range toSort { + contacts = append(contacts, sorted.contact) + if len(contacts) >= limit { break } } - return nodes + return contacts } -// Count returns the number of nodes in the routing table +// Count returns the number of contacts in the routing table func (rt *routingTable) Count() int { rt.lock.RLock() defer rt.lock.RUnlock() @@ -258,38 +257,24 @@ func (rt *routingTable) Count() int { func findInList(bucket *list.List, value Bitmap) *list.Element { for curr := bucket.Front(); curr != nil; curr = curr.Next() { - if curr.Value.(Node).id.Equals(value) { + if curr.Value.(Contact).id.Equals(value) { return curr } } return nil } -func appendNodes(nodes []sortedNode, start *list.Element, target Bitmap) []sortedNode { +func appendContacts(contacts []sortedContact, start *list.Element, target Bitmap) []sortedContact { for curr := start; curr != nil; curr = curr.Next() { - node := curr.Value.(Node) - nodes = append(nodes, sortedNode{node, node.id.Xor(target)}) + c := curr.Value.(Contact) + contacts = append(contacts, sortedContact{c, c.id.Xor(target)}) } - return nodes + return contacts } func bucketFor(id Bitmap, target Bitmap) int { if id.Equals(target) { - panic("nodes do not have a bucket for themselves") + panic("routing table does not have a bucket for its own id") } return numBuckets - 1 - target.Xor(id).PrefixLen() } - -func sortNodesInPlace(nodes []Node, target Bitmap) { - toSort := make([]sortedNode, len(nodes)) - - for i, n := range nodes { - toSort[i] = sortedNode{n, n.id.Xor(target)} - } - - sort.Sort(byXorDistance(toSort)) - - for i, c := range toSort { - nodes[i] = c.node - } -} diff --git a/dht/routing_table_test.go b/dht/routing_table_test.go index d7094df..e4451c3 100644 --- a/dht/routing_table_test.go +++ b/dht/routing_table_test.go @@ -36,9 +36,9 @@ func TestRoutingTable(t *testing.T) { n1 := BitmapFromHexP("FFFFFFFF0000000000000000000000000000000000000000000000000000000000000000000000000000000000000000") n2 := BitmapFromHexP("FFFFFFF00000000000000000000000000000000000000000000000000000000000000000000000000000000000000000") n3 := BitmapFromHexP("111111110000000000000000000000000000000000000000000000000000000000000000000000000000000000000000") - rt := newRoutingTable(&Node{n1, net.ParseIP("127.0.0.1"), 8000, ""}) - rt.Update(Node{n2, net.ParseIP("127.0.0.1"), 8001, ""}) - rt.Update(Node{n3, net.ParseIP("127.0.0.1"), 8002, ""}) + rt := newRoutingTable(n1) + rt.Update(Contact{n2, net.ParseIP("127.0.0.1"), 8001}) + rt.Update(Contact{n3, net.ParseIP("127.0.0.1"), 8002}) contacts := rt.GetClosest(BitmapFromHexP("222222220000000000000000000000000000000000000000000000000000000000000000000000000000000000000000"), 1) if len(contacts) != 1 { @@ -63,14 +63,14 @@ func TestRoutingTable(t *testing.T) { } func TestCompactEncoding(t *testing.T) { - n := Node{ + c := Contact{ id: BitmapFromHexP("1c8aff71b99462464d9eeac639595ab99664be3482cb91a29d87467515c7d9158fe72aa1f1582dab07d8f8b5db277f41"), ip: net.ParseIP("1.2.3.4"), port: int(55<<8 + 66), } var compact []byte - compact, err := n.MarshalCompact() + compact, err := c.MarshalCompact() if err != nil { t.Fatal(err) } @@ -79,7 +79,7 @@ func TestCompactEncoding(t *testing.T) { t.Fatalf("got length of %d; expected %d", len(compact), compactNodeInfoLength) } - if !reflect.DeepEqual(compact, append([]byte{1, 2, 3, 4, 55, 66}, n.id[:]...)) { + if !reflect.DeepEqual(compact, append([]byte{1, 2, 3, 4, 55, 66}, c.id[:]...)) { t.Errorf("compact bytes not encoded correctly") } } diff --git a/dht/rpc.go b/dht/rpc.go deleted file mode 100644 index b0f463d..0000000 --- a/dht/rpc.go +++ /dev/null @@ -1,178 +0,0 @@ -package dht - -import ( - "encoding/hex" - "net" - "time" - - "github.com/lbryio/errors.go" - "github.com/lbryio/lbry.go/util" - - "github.com/davecgh/go-spew/spew" - "github.com/lyoshenka/bencode" - log "github.com/sirupsen/logrus" -) - -// handlePacket handles packets received from udp. -func handlePacket(dht *DHT, pkt packet) { - //log.Debugf("[%s] Received message from %s (%d bytes) %s", dht.node.id.HexShort(), pkt.raddr.String(), len(pkt.data), hex.EncodeToString(pkt.data)) - - if !util.InSlice(string(pkt.data[0:5]), []string{"d1:0i", "di0ei"}) { - log.Errorf("[%s] data is not a well-formatted dict: (%d bytes) %s", dht.node.id.HexShort(), len(pkt.data), hex.EncodeToString(pkt.data)) - return - } - - // TODO: test this stuff more thoroughly - - // the following is a bit of a hack, but it lets us avoid decoding every message twice - // it depends on the data being a dict with 0 as the first key (so it starts with "d1:0i") and the message type as the first value - - switch pkt.data[5] { - case '0' + requestType: - request := Request{} - err := bencode.DecodeBytes(pkt.data, &request) - if err != nil { - log.Errorf("[%s] error decoding request from %s: %s: (%d bytes) %s", dht.node.id.HexShort(), pkt.raddr.String(), err.Error(), len(pkt.data), hex.EncodeToString(pkt.data)) - return - } - log.Debugf("[%s] query %s: received request from %s: %s(%s)", dht.node.id.HexShort(), request.ID.HexShort(), request.NodeID.HexShort(), request.Method, request.ArgsDebug()) - handleRequest(dht, pkt.raddr, request) - - case '0' + responseType: - response := Response{} - err := bencode.DecodeBytes(pkt.data, &response) - if err != nil { - log.Errorf("[%s] error decoding response from %s: %s: (%d bytes) %s", dht.node.id.HexShort(), pkt.raddr.String(), err.Error(), len(pkt.data), hex.EncodeToString(pkt.data)) - return - } - log.Debugf("[%s] query %s: received response from %s: %s", dht.node.id.HexShort(), response.ID.HexShort(), response.NodeID.HexShort(), response.ArgsDebug()) - handleResponse(dht, pkt.raddr, response) - - case '0' + errorType: - e := Error{} - err := bencode.DecodeBytes(pkt.data, &e) - if err != nil { - log.Errorf("[%s] error decoding error from %s: %s: (%d bytes) %s", dht.node.id.HexShort(), pkt.raddr.String(), err.Error(), len(pkt.data), hex.EncodeToString(pkt.data)) - return - } - log.Debugf("[%s] query %s: received error from %s: %s", dht.node.id.HexShort(), e.ID.HexShort(), e.NodeID.HexShort(), e.ExceptionType) - handleError(dht, pkt.raddr, e) - - default: - log.Errorf("[%s] invalid message type: %s", dht.node.id.HexShort(), pkt.data[5]) - return - } -} - -// handleRequest handles the requests received from udp. -func handleRequest(dht *DHT, addr *net.UDPAddr, request Request) { - if request.NodeID.Equals(dht.node.id) { - log.Warn("ignoring self-request") - return - } - - switch request.Method { - default: - // send(dht, addr, makeError(t, protocolError, "invalid q")) - log.Errorln("invalid request method") - return - case pingMethod: - send(dht, addr, Response{ID: request.ID, NodeID: dht.node.id, Data: pingSuccessResponse}) - case storeMethod: - // TODO: we should be sending the IP in the request, not just using the sender's IP - // TODO: should we be using StoreArgs.NodeID or StoreArgs.Value.LbryID ??? - if dht.tokens.Verify(request.StoreArgs.Value.Token, request.NodeID, addr) { - dht.store.Upsert(request.StoreArgs.BlobHash, Node{id: request.StoreArgs.NodeID, ip: addr.IP, port: request.StoreArgs.Value.Port}) - send(dht, addr, Response{ID: request.ID, NodeID: dht.node.id, Data: storeSuccessResponse}) - } else { - send(dht, addr, Error{ID: request.ID, NodeID: dht.node.id, ExceptionType: "invalid-token"}) - } - case findNodeMethod: - if request.Arg == nil { - log.Errorln("request is missing arg") - return - } - send(dht, addr, getFindResponse(dht, request)) - - case findValueMethod: - if request.Arg == nil { - log.Errorln("request is missing arg") - return - } - - if nodes := dht.store.Get(*request.Arg); len(nodes) > 0 { - send(dht, addr, Response{ - ID: request.ID, - NodeID: dht.node.id, - FindValueKey: request.Arg.RawString(), - FindNodeData: nodes, - Token: dht.tokens.Get(request.NodeID, addr), - }) - } else { - res := getFindResponse(dht, request) - res.Token = dht.tokens.Get(request.NodeID, addr) - send(dht, addr, res) - } - } - - // nodes that send us requests should not be inserted, only refreshed. - // the routing table must only contain "good" nodes, which are nodes that reply to our requests - // if a node is already good (aka in the table), its fine to refresh it - // http://www.bittorrent.org/beps/bep_0005.html#routing-table - node := Node{id: request.NodeID, ip: addr.IP, port: addr.Port} - dht.rt.UpdateIfExists(node) -} - -func getFindResponse(dht *DHT, request Request) Response { - closestNodes := dht.rt.GetClosest(*request.Arg, bucketSize) - response := Response{ - ID: request.ID, - NodeID: dht.node.id, - FindNodeData: make([]Node, len(closestNodes)), - } - for i, n := range closestNodes { - response.FindNodeData[i] = n - } - return response -} - -// handleResponse handles responses received from udp. -func handleResponse(dht *DHT, addr *net.UDPAddr, response Response) { - tx := dht.tm.Find(response.ID, addr) - if tx != nil { - tx.res <- response - } - - node := Node{id: response.NodeID, ip: addr.IP, port: addr.Port} - dht.rt.Update(node) -} - -// handleError handles errors received from udp. -func handleError(dht *DHT, addr *net.UDPAddr, e Error) { - spew.Dump(e) - node := Node{id: e.NodeID, ip: addr.IP, port: addr.Port} - dht.rt.UpdateIfExists(node) -} - -// send sends data to a udp address -func send(dht *DHT, addr *net.UDPAddr, data Message) error { - encoded, err := bencode.EncodeBytes(data) - if err != nil { - return errors.Err(err) - } - - if req, ok := data.(Request); ok { - log.Debugf("[%s] query %s: sending request to %s (%d bytes) %s(%s)", - dht.node.id.HexShort(), req.ID.HexShort(), addr.String(), len(encoded), req.Method, req.ArgsDebug()) - } else if res, ok := data.(Response); ok { - log.Debugf("[%s] query %s: sending response to %s (%d bytes) %s", - dht.node.id.HexShort(), res.ID.HexShort(), addr.String(), len(encoded), res.ArgsDebug()) - } else { - log.Debugf("[%s] (%d bytes) %s", dht.node.id.HexShort(), len(encoded), spew.Sdump(data)) - } - - dht.conn.SetWriteDeadline(time.Now().Add(time.Second * 15)) - - _, err = dht.conn.WriteToUDP(encoded, addr) - return errors.Err(err) -} diff --git a/dht/store.go b/dht/store.go index 794e655..76f33fb 100644 --- a/dht/store.go +++ b/dht/store.go @@ -3,7 +3,7 @@ package dht import "sync" type peer struct { - node Node + contact Contact //, // // @@ -12,42 +12,48 @@ type peer struct { type peerStore struct { // map of blob hashes to (map of node IDs to bools) hashes map[Bitmap]map[Bitmap]bool - // map of node IDs to peers - nodeInfo map[Bitmap]peer - lock sync.RWMutex + // stores the peers themselves, so they can be updated in one place + peers map[Bitmap]peer + lock sync.RWMutex } func newPeerStore() *peerStore { return &peerStore{ - hashes: make(map[Bitmap]map[Bitmap]bool), - nodeInfo: make(map[Bitmap]peer), + hashes: make(map[Bitmap]map[Bitmap]bool), + peers: make(map[Bitmap]peer), } } -func (s *peerStore) Upsert(blobHash Bitmap, node Node) { +func (s *peerStore) Upsert(blobHash Bitmap, contact Contact) { s.lock.Lock() defer s.lock.Unlock() + if _, ok := s.hashes[blobHash]; !ok { s.hashes[blobHash] = make(map[Bitmap]bool) } - s.hashes[blobHash][node.id] = true - s.nodeInfo[node.id] = peer{node: node} + s.hashes[blobHash][contact.id] = true + s.peers[contact.id] = peer{contact: contact} } -func (s *peerStore) Get(blobHash Bitmap) []Node { +func (s *peerStore) Get(blobHash Bitmap) []Contact { s.lock.RLock() defer s.lock.RUnlock() - var nodes []Node + + var contacts []Contact if ids, ok := s.hashes[blobHash]; ok { for id := range ids { - peer, ok := s.nodeInfo[id] + peer, ok := s.peers[id] if !ok { panic("node id in IDs list, but not in nodeInfo") } - nodes = append(nodes, peer.node) + contacts = append(contacts, peer.contact) } } - return nodes + return contacts +} + +func (s *peerStore) RemoveTODO(contact Contact) { + // TODO: remove peer from everywhere } func (s *peerStore) CountStoredHashes() int { diff --git a/dht/transaction_manager.go b/dht/transaction_manager.go deleted file mode 100644 index 4141bbf..0000000 --- a/dht/transaction_manager.go +++ /dev/null @@ -1,125 +0,0 @@ -package dht - -import ( - "context" - "net" - "strings" - "sync" - "time" - - log "github.com/sirupsen/logrus" -) - -// transaction represents a single query to the dht. it stores the queried node, the request, and the response channel -type transaction struct { - node Node - req Request - res chan Response -} - -// transactionManager keeps track of the outstanding transactions -type transactionManager struct { - dht *DHT - lock *sync.RWMutex - transactions map[messageID]*transaction -} - -// newTransactionManager returns a new transactionManager -func newTransactionManager(dht *DHT) *transactionManager { - return &transactionManager{ - lock: &sync.RWMutex{}, - transactions: make(map[messageID]*transaction), - dht: dht, - } -} - -// insert adds a transaction to the manager. -func (tm *transactionManager) insert(tx *transaction) { - tm.lock.Lock() - defer tm.lock.Unlock() - tm.transactions[tx.req.ID] = tx -} - -// delete removes a transaction from the manager. -func (tm *transactionManager) delete(id messageID) { - tm.lock.Lock() - defer tm.lock.Unlock() - delete(tm.transactions, id) -} - -// Find finds a transaction for the given id. it optionally ensures that addr matches node from transaction -func (tm *transactionManager) Find(id messageID, addr *net.UDPAddr) *transaction { - tm.lock.RLock() - defer tm.lock.RUnlock() - - // TODO: also check that the response's nodeid matches the id you thought you sent to? - - t, ok := tm.transactions[id] - if !ok || (addr != nil && t.node.Addr().String() != addr.String()) { - return nil - } - - return t -} - -// SendAsync sends a transaction and returns a channel that will eventually contain the transaction response -// The response channel is closed when the transaction is completed or times out. -func (tm *transactionManager) SendAsync(ctx context.Context, node Node, req Request) <-chan *Response { - if node.id.Equals(tm.dht.node.id) { - log.Error("sending query to self") - return nil - } - - ch := make(chan *Response, 1) - - go func() { - defer close(ch) - - req.ID = newMessageID() - req.NodeID = tm.dht.node.id - tx := &transaction{ - node: node, - req: req, - res: make(chan Response), - } - - tm.insert(tx) - defer tm.delete(tx.req.ID) - - for i := 0; i < udpRetry; i++ { - if err := send(tm.dht, node.Addr(), tx.req); err != nil { - if !strings.Contains(err.Error(), "use of closed network connection") { // this only happens on localhost. real UDP has no connections - log.Error("send error: ", err) - } - continue // try again? return? - } - - select { - case res := <-tx.res: - ch <- &res - return - case <-ctx.Done(): - return - case <-time.After(udpTimeout): - } - } - - // if request timed out each time - tm.dht.rt.RemoveByID(tx.node.id) - }() - - return ch -} - -// Send sends a transaction and blocks until the response is available. It returns a response, or nil -// if the transaction timed out. -func (tm *transactionManager) Send(node Node, req Request) *Response { - return <-tm.SendAsync(context.Background(), node, req) -} - -// Count returns the number of transactions in the manager -func (tm *transactionManager) Count() int { - tm.lock.Lock() - defer tm.lock.Unlock() - return len(tm.transactions) -}