diff --git a/cmd/dht.go b/cmd/dht.go new file mode 100644 index 0000000..0c73d07 --- /dev/null +++ b/cmd/dht.go @@ -0,0 +1,30 @@ +package cmd + +import ( + "github.com/lbryio/reflector.go/dht" + + log "github.com/sirupsen/logrus" + "github.com/spf13/cobra" +) + +func init() { + var cmd = &cobra.Command{ + Use: "dht", + Short: "Run interactive dht node", + Run: dhtCmd, + } + RootCmd.AddCommand(cmd) +} + +func dhtCmd(cmd *cobra.Command, args []string) { + dht, err := dht.New(&dht.Config{ + Address: "127.0.0.1:21216", + SeedNodes: []string{"127.0.0.1:21215"}, + PrintState: true, + }) + if err != nil { + log.Fatal(err) + } + + dht.Run() +} diff --git a/dht/bitmap.go b/dht/bitmap.go index 90f83fa..a1236ca 100644 --- a/dht/bitmap.go +++ b/dht/bitmap.go @@ -1,11 +1,11 @@ package dht import ( + "crypto/rand" "encoding/hex" - "math/rand" "strconv" - "github.com/zeebo/bencode" + "github.com/lyoshenka/bencode" ) type bitmap [nodeIDLength]byte @@ -45,7 +45,7 @@ func (b bitmap) Xor(other bitmap) bitmap { } // PrefixLen returns the number of leading 0 bits -func (b bitmap) PrefixLen() (ret int) { +func (b bitmap) PrefixLen() int { for i := range b { for j := 0; j < 8; j++ { if (b[i]>>uint8(7-j))&0x1 != 0 { @@ -95,8 +95,9 @@ func newBitmapFromHex(hexStr string) bitmap { func newRandomBitmap() bitmap { var id bitmap - for k := range id { - id[k] = uint8(rand.Intn(256)) + _, err := rand.Read(id[:]) + if err != nil { + panic(err) } return id } diff --git a/dht/bitmap_test.go b/dht/bitmap_test.go index 16a9bec..b1ca654 100644 --- a/dht/bitmap_test.go +++ b/dht/bitmap_test.go @@ -3,7 +3,7 @@ package dht import ( "testing" - "github.com/zeebo/bencode" + "github.com/lyoshenka/bencode" ) func TestBitmap(t *testing.T) { diff --git a/dht/conn.go b/dht/conn.go deleted file mode 100644 index 3704a70..0000000 --- a/dht/conn.go +++ /dev/null @@ -1,60 +0,0 @@ -package dht - -import ( - "net" - "strconv" - "strings" - "time" -) - -type UDPConn interface { - ReadFromUDP([]byte) (int, *net.UDPAddr, error) - WriteToUDP([]byte, *net.UDPAddr) (int, error) - SetWriteDeadline(time.Time) error -} - -type testUDPPacket struct { - data []byte - addr *net.UDPAddr -} - -type testUDPConn struct { - addr *net.UDPAddr - toRead chan testUDPPacket - writes chan testUDPPacket -} - -func newTestUDPConn(addr string) *testUDPConn { - parts := strings.Split(addr, ":") - if len(parts) != 2 { - panic("addr needs ip and port") - } - port, err := strconv.Atoi(parts[1]) - if err != nil { - panic(err) - } - return &testUDPConn{ - addr: &net.UDPAddr{IP: net.IP(parts[0]), Port: port}, - toRead: make(chan testUDPPacket), - writes: make(chan testUDPPacket), - } -} - -func (t testUDPConn) ReadFromUDP(b []byte) (int, *net.UDPAddr, error) { - select { - case packet := <-t.toRead: - n := copy(b, packet.data) - return n, packet.addr, nil - //default: - // return 0, nil, nil - } -} - -func (t testUDPConn) WriteToUDP(b []byte, addr *net.UDPAddr) (int, error) { - t.writes <- testUDPPacket{data: b, addr: addr} - return len(b), nil -} - -func (t testUDPConn) SetWriteDeadline(tm time.Time) error { - return nil -} diff --git a/dht/decode_test.go b/dht/decode_test.go index 70feab0..d7c0764 100644 --- a/dht/decode_test.go +++ b/dht/decode_test.go @@ -5,7 +5,7 @@ import ( "testing" "github.com/davecgh/go-spew/spew" - "github.com/zeebo/bencode" + "github.com/lyoshenka/bencode" ) func TestDecode(t *testing.T) { diff --git a/dht/dht.go b/dht/dht.go index 03de2fa..66abd66 100644 --- a/dht/dht.go +++ b/dht/dht.go @@ -1,23 +1,24 @@ package dht import ( - "encoding/hex" "net" - "reflect" - "strings" "time" - "github.com/davecgh/go-spew/spew" + "github.com/lbryio/errors.go" + log "github.com/sirupsen/logrus" "github.com/spf13/cast" - "github.com/zeebo/bencode" ) const network = "udp4" -const alpha = 3 // this is the constant alpha in the spec -const nodeIDLength = 48 // bytes. this is the constant B in the spec -const bucketSize = 8 // this is the constant k in the spec +const alpha = 3 // this is the constant alpha in the spec +const nodeIDLength = 48 // bytes. this is the constant B in the spec +const messageIDLength = 20 // bytes. +const bucketSize = 8 // this is the constant k in the spec + +const udpRetry = 3 +const udpTimeout = 10 * time.Second const tExpire = 86400 * time.Second // the time after which a key/value pair expires; this is a time-to-live (TTL) from the original publication date const tRefresh = 3600 * time.Second // the time after which an otherwise unaccessed bucket must be refreshed @@ -41,6 +42,8 @@ type Config struct { SeedNodes []string // the hex-encoded node id for this node. if string is empty, a random id will be generated NodeID string + // print the state of the dht every minute + PrintState bool } // NewStandardConfig returns a Config pointer with default values. @@ -55,18 +58,26 @@ func NewStandardConfig() *Config { } } +// UDPConn allows using a mocked connection for testing sending/receiving data +type UDPConn interface { + ReadFromUDP([]byte) (int, *net.UDPAddr, error) + WriteToUDP([]byte, *net.UDPAddr) (int, error) + SetWriteDeadline(time.Time) error +} + // DHT represents a DHT node. type DHT struct { - conf *Config - conn UDPConn - node *Node - routingTable *RoutingTable - packets chan packet - store *peerStore + conf *Config + conn UDPConn + node *Node + rt *RoutingTable + packets chan packet + store *peerStore + tm *transactionManager } // New returns a DHT pointer. If config is nil, then config will be set to the default config. -func New(config *Config) *DHT { +func New(config *Config) (*DHT, error) { if config == nil { config = NewStandardConfig() } @@ -80,41 +91,51 @@ func New(config *Config) *DHT { ip, port, err := net.SplitHostPort(config.Address) if err != nil { - panic(err) + return nil, errors.Err(err) } else if ip == "" { - panic("address does not contain an IP") + return nil, errors.Err("address does not contain an IP") } else if port == "" { - panic("address does not contain a port") + return nil, errors.Err("address does not contain a port") } portInt, err := cast.ToIntE(port) if err != nil { - panic(err) + return nil, errors.Err(err) } node := &Node{id: id, ip: net.ParseIP(ip), port: portInt} if node.ip == nil { - panic("invalid ip") + return nil, errors.Err("invalid ip") } - return &DHT{ - conf: config, - node: node, - routingTable: newRoutingTable(node), - packets: make(chan packet), - store: newPeerStore(), + + d := &DHT{ + conf: config, + node: node, + rt: newRoutingTable(node), + packets: make(chan packet), + store: newPeerStore(), } + d.tm = newTransactionManager(d) + return d, nil } // init initializes global variables. -func (dht *DHT) init() { +func (dht *DHT) init() error { log.Info("Initializing DHT on " + dht.conf.Address) log.Infof("Node ID is %s", dht.node.id.Hex()) + listener, err := net.ListenPacket(network, dht.conf.Address) if err != nil { - panic(err) + return errors.Err(err) } dht.conn = listener.(*net.UDPConn) + + if dht.conf.PrintState { + go printState(dht) + } + + return nil } // listen receives message from udp. @@ -159,201 +180,98 @@ func (dht *DHT) runHandler() { for { select { case pkt = <-dht.packets: - handle(dht, pkt) + handlePacket(dht, pkt) } } } // Run starts the dht. -func (dht *DHT) Run() { - dht.init() +func (dht *DHT) Run() error { + err := dht.init() + if err != nil { + return err + } + dht.listen() dht.join() log.Info("DHT ready") dht.runHandler() + return nil } -// handle handles packets received from udp. -func handle(dht *DHT, pkt packet) { - //log.Infof("Received message from %s:%s : %s\n", pkt.raddr.IP.String(), strconv.Itoa(pkt.raddr.Port), hex.EncodeToString(pkt.data)) - - var data map[string]interface{} - err := bencode.DecodeBytes(pkt.data, &data) - if err != nil { - log.Errorf("error decoding data: %s\n%s", err, pkt.data) - return - } - - msgType, ok := data[headerTypeField] - if !ok { - log.Errorf("decoded data has no message type: %s", data) - return - } - - switch msgType.(int64) { - case requestType: - request := Request{} - err = bencode.DecodeBytes(pkt.data, &request) - if err != nil { - log.Errorln(err) - return - } - log.Debugf("[%s] query %s: received request from %s: %s(%s)", dht.node.id.Hex()[:8], hex.EncodeToString([]byte(request.ID))[:8], hex.EncodeToString([]byte(request.NodeID))[:8], request.Method, argsToString(request.Args)) - handleRequest(dht, pkt.raddr, request) - - case responseType: - response := Response{} - err = bencode.DecodeBytes(pkt.data, &response) - if err != nil { - return - } - log.Debugf("[%s] query %s: received response from %s: %s", dht.node.id.Hex()[:8], hex.EncodeToString([]byte(response.ID))[:8], hex.EncodeToString([]byte(response.NodeID))[:8], response.Data) - handleResponse(dht, pkt.raddr, response) - - case errorType: - e := Error{ - ID: data[headerMessageIDField].(string), - NodeID: data[headerNodeIDField].(string), - ExceptionType: data[headerPayloadField].(string), - Response: getArgs(data[headerArgsField]), - } - log.Debugf("[%s] query %s: received error from %s: %s", dht.node.id.Hex()[:8], hex.EncodeToString([]byte(e.ID))[:8], hex.EncodeToString([]byte(e.NodeID))[:8], e.ExceptionType) - handleError(dht, pkt.raddr, e) - - default: - log.Errorf("Invalid message type: %s", msgType) - return +func printState(dht *DHT) { + t := time.NewTicker(60 * time.Second) + for { + log.Printf("DHT state at %s", time.Now().Format(time.RFC822Z)) + log.Printf("Outstanding transactions: %d", dht.tm.Count()) + log.Printf("Known nodes: %d", dht.store.CountKnownNodes()) + log.Printf("Buckets: \n%s", dht.rt.BucketInfo()) + <-t.C } } -// handleRequest handles the requests received from udp. -func handleRequest(dht *DHT, addr *net.UDPAddr, request Request) { - if request.NodeID == dht.node.id.RawString() { - log.Warn("ignoring self-request") - return - } - - switch request.Method { - case pingMethod: - send(dht, addr, Response{ID: request.ID, NodeID: dht.node.id.RawString(), Data: pingSuccessResponse}) - case storeMethod: - if request.StoreArgs.BlobHash == "" { - log.Errorln("blobhash is empty") - return // nothing to store - } - // TODO: we should be sending the IP in the request, not just using the sender's IP - // TODO: should we be using StoreArgs.NodeID or StoreArgs.Value.LbryID ??? - dht.store.Insert(request.StoreArgs.BlobHash, Node{id: request.StoreArgs.NodeID, ip: addr.IP, port: request.StoreArgs.Value.Port}) - send(dht, addr, Response{ID: request.ID, NodeID: dht.node.id.RawString(), Data: storeSuccessResponse}) - case findNodeMethod: - log.Println("findnode") - if len(request.Args) < 1 { - log.Errorln("nothing to find") - return - } - if len(request.Args[0]) != nodeIDLength { - log.Errorln("invalid node id") - return - } - doFindNodes(dht, addr, request) - case findValueMethod: - log.Println("findvalue") - if len(request.Args) < 1 { - log.Errorln("nothing to find") - return - } - if len(request.Args[0]) != nodeIDLength { - log.Errorln("invalid node id") - return - } - - if nodes := dht.store.Get(request.Args[0]); len(nodes) > 0 { - response := Response{ID: request.ID, NodeID: dht.node.id.RawString()} - response.FindValueKey = request.Args[0] - response.FindNodeData = nodes - send(dht, addr, response) - } else { - doFindNodes(dht, addr, request) - } - - default: - // send(dht, addr, makeError(t, protocolError, "invalid q")) - log.Errorln("invalid request method") - return - } - - node := &Node{id: newBitmapFromString(request.NodeID), ip: addr.IP, port: addr.Port} - dht.routingTable.Update(node) -} - -func doFindNodes(dht *DHT, addr *net.UDPAddr, request Request) { - nodeID := newBitmapFromString(request.Args[0]) - closestNodes := dht.routingTable.FindClosest(nodeID, bucketSize) - if len(closestNodes) > 0 { - response := Response{ID: request.ID, NodeID: dht.node.id.RawString(), FindNodeData: make([]Node, len(closestNodes))} - for i, n := range closestNodes { - response.FindNodeData[i] = *n - } - send(dht, addr, response) - } -} - -// handleResponse handles responses received from udp. -func handleResponse(dht *DHT, addr *net.UDPAddr, response Response) { - spew.Dump(response) - - // TODO: find transaction by message id, pass along response - - node := &Node{id: newBitmapFromString(response.NodeID), ip: addr.IP, port: addr.Port} - dht.routingTable.Update(node) -} - -// handleError handles errors received from udp. -func handleError(dht *DHT, addr *net.UDPAddr, e Error) { - spew.Dump(e) - node := &Node{id: newBitmapFromString(e.NodeID), ip: addr.IP, port: addr.Port} - dht.routingTable.Update(node) -} - -// send sends data to the udp. -func send(dht *DHT, addr *net.UDPAddr, data Message) error { - if req, ok := data.(Request); ok { - log.Debugf("[%s] query %s: sending request: %s(%s)", dht.node.id.Hex()[:8], hex.EncodeToString([]byte(req.ID))[:8], req.Method, argsToString(req.Args)) - } else if res, ok := data.(Response); ok { - log.Debugf("[%s] query %s: sending response: %s", dht.node.id.Hex()[:8], hex.EncodeToString([]byte(res.ID))[:8], spew.Sdump(res.Data)) - } else { - log.Debugf("[%s] %s", spew.Sdump(data)) - } - encoded, err := bencode.EncodeBytes(data) - if err != nil { - return err - } - //log.Infof("Encoded: %s", string(encoded)) - - dht.conn.SetWriteDeadline(time.Now().Add(time.Second * 15)) - - _, err = dht.conn.WriteToUDP(encoded, addr) - return err -} - -func getArgs(argsInt interface{}) []string { - var args []string - if reflect.TypeOf(argsInt).Kind() == reflect.Slice { - v := reflect.ValueOf(argsInt) - for i := 0; i < v.Len(); i++ { - args = append(args, cast.ToString(v.Index(i).Interface())) - } - } - return args -} - -func argsToString(args []string) string { - argsCopy := make([]string, len(args)) - copy(argsCopy, args) - for k, v := range argsCopy { - if len(v) == nodeIDLength { - argsCopy[k] = hex.EncodeToString([]byte(v))[:8] - } - } - return strings.Join(argsCopy, ", ") -} +//func (dht *DHT) Get(hash bitmap) ([]Node, error) { +// return iterativeFindNode(dht, hash) +//} +// +//func iterativeFindNode(dht *DHT, hash bitmap) ([]Node, error) { +// shortlist := dht.rt.FindClosest(hash, alpha) +// if len(shortlist) == 0 { +// return nil, errors.Err("no nodes in routing table") +// } +// +// pending := make(chan *Node) +// contacted := make(map[bitmap]bool) +// contactedMutex := &sync.Mutex{} +// closestNodeMutex := &sync.Mutex{} +// closestNode := shortlist[0] +// wg := sync.WaitGroup{} +// +// for i := 0; i < alpha; i++ { +// wg.Add(1) +// go func() { +// defer wg.Done() +// for { +// node, ok := <-pending +// if !ok { +// return +// } +// +// contactedMutex.Lock() +// if _, ok := contacted[node.id]; ok { +// contactedMutex.Unlock() +// continue +// } +// contacted[node.id] = true +// contactedMutex.Unlock() +// +// res := dht.tm.Send(node, &Request{ +// NodeID: dht.node.id.RawString(), +// Method: findNodeMethod, +// Args: []string{hash.RawString()}, +// }) +// if res == nil { +// // remove node from shortlist +// continue +// } +// +// for _, n := range res.FindNodeData { +// pending <- &n +// closestNodeMutex.Lock() +// if n.id.Xor(hash).Less(closestNode.id.Xor(hash)) { +// closestNode = &n +// } +// closestNodeMutex.Unlock() +// } +// } +// }() +// } +// +// for _, n := range shortlist { +// pending <- n +// } +// +// wg.Wait() +// +// return nil, nil +//} diff --git a/dht/message.go b/dht/message.go index e7cb440..f826e0d 100644 --- a/dht/message.go +++ b/dht/message.go @@ -1,10 +1,12 @@ package dht import ( + "encoding/hex" + "github.com/lbryio/errors.go" + "github.com/lyoshenka/bencode" "github.com/spf13/cast" - "github.com/zeebo/bencode" ) const ( @@ -171,6 +173,21 @@ type Response struct { FindValueKey string } +func (r Response) ArgsDebug() string { + if len(r.FindNodeData) == 0 { + return r.Data + } + + str := "contacts " + if r.FindValueKey != "" { + str += "for " + hex.EncodeToString([]byte(r.FindValueKey))[:8] + " " + } + for _, c := range r.FindNodeData { + str += c.Addr().String() + ":" + c.id.Hex()[:8] + ", " + } + return str[:len(str)-2] // chomp off last ", " +} + func (r Response) MarshalBencode() ([]byte, error) { data := map[string]interface{}{ headerTypeField: responseType, diff --git a/dht/message_test.go b/dht/message_test.go index b03b46d..4dc4367 100644 --- a/dht/message_test.go +++ b/dht/message_test.go @@ -7,8 +7,8 @@ import ( "strings" "testing" + "github.com/lyoshenka/bencode" log "github.com/sirupsen/logrus" - "github.com/zeebo/bencode" ) func TestBencodeDecodeStoreArgs(t *testing.T) { diff --git a/dht/routing_table.go b/dht/routing_table.go index 138d310..a433ac2 100644 --- a/dht/routing_table.go +++ b/dht/routing_table.go @@ -3,12 +3,15 @@ package dht import ( "bytes" "container/list" + "fmt" "net" "sort" + "strings" + "sync" "github.com/lbryio/errors.go" - "github.com/zeebo/bencode" + "github.com/lyoshenka/bencode" ) type Node struct { @@ -17,6 +20,10 @@ type Node struct { port int } +func (n Node) Addr() *net.UDPAddr { + return &net.UDPAddr{IP: n.ip, Port: n.port} +} + func (n Node) MarshalCompact() ([]byte, error) { if n.ip.To4() == nil { return nil, errors.Err("ip not set") @@ -102,6 +109,7 @@ func (a byXorDistance) Less(i, j int) bool { type RoutingTable struct { node Node buckets [numBuckets]*list.List + lock *sync.RWMutex } func newRoutingTable(node *Node) *RoutingTable { @@ -110,39 +118,73 @@ func newRoutingTable(node *Node) *RoutingTable { rt.buckets[i] = list.New() } rt.node = *node + rt.lock = &sync.RWMutex{} return &rt } +func (rt *RoutingTable) BucketInfo() string { + rt.lock.RLock() + defer rt.lock.RUnlock() + + bucketInfo := []string{} + for i, b := range rt.buckets { + count := countInList(b) + if count > 0 { + bucketInfo = append(bucketInfo, fmt.Sprintf("Bucket %d: %d", i, count)) + } + } + if len(bucketInfo) == 0 { + return "buckets are empty" + } + return strings.Join(bucketInfo, "\n") +} + func (rt *RoutingTable) Update(node *Node) { - prefixLength := node.id.Xor(rt.node.id).PrefixLen() - bucket := rt.buckets[prefixLength] + rt.lock.Lock() + defer rt.lock.Unlock() + bucketNum := bucketFor(rt.node.id, node.id) + bucket := rt.buckets[bucketNum] element := findInList(bucket, rt.node.id) if element == nil { - if bucket.Len() <= bucketSize { - bucket.PushBack(node) + if bucket.Len() >= bucketSize { + // TODO: Ping front node first. Only remove if it does not respond + bucket.Remove(bucket.Front()) } - // TODO: Handle insertion when the list is full by evicting old elements if - // they don't respond to a ping. + bucket.PushBack(node) } else { bucket.MoveToBack(element) } } -func (rt *RoutingTable) FindClosest(target bitmap, count int) []*Node { +func (rt *RoutingTable) RemoveByID(id bitmap) { + rt.lock.Lock() + defer rt.lock.Unlock() + bucketNum := bucketFor(rt.node.id, id) + bucket := rt.buckets[bucketNum] + element := findInList(bucket, rt.node.id) + if element != nil { + bucket.Remove(element) + } +} + +func (rt *RoutingTable) FindClosest(target bitmap, limit int) []*Node { + rt.lock.RLock() + defer rt.lock.RUnlock() + var toSort []*SortedNode - prefixLength := target.Xor(rt.node.id).PrefixLen() - bucket := rt.buckets[prefixLength] - toSort = appendNodes(toSort, bucket.Front(), nil, target) + bucketNum := bucketFor(rt.node.id, target) + bucket := rt.buckets[bucketNum] + toSort = appendNodes(toSort, bucket.Front(), target) - for i := 1; (prefixLength-i >= 0 || prefixLength+i < numBuckets) && len(toSort) < count; i++ { - if prefixLength-i >= 0 { - bucket = rt.buckets[prefixLength-i] - toSort = appendNodes(toSort, bucket.Front(), nil, target) + for i := 1; (bucketNum-i >= 0 || bucketNum+i < numBuckets) && len(toSort) < limit; i++ { + if bucketNum-i >= 0 { + bucket = rt.buckets[bucketNum-i] + toSort = appendNodes(toSort, bucket.Front(), target) } - if prefixLength+i < numBuckets { - bucket = rt.buckets[prefixLength+i] - toSort = appendNodes(toSort, bucket.Front(), nil, target) + if bucketNum+i < numBuckets { + bucket = rt.buckets[bucketNum+i] + toSort = appendNodes(toSort, bucket.Front(), target) } } @@ -151,6 +193,9 @@ func (rt *RoutingTable) FindClosest(target bitmap, count int) []*Node { var nodes []*Node for _, c := range toSort { nodes = append(nodes, c.node) + if len(nodes) >= limit { + break + } } return nodes @@ -165,10 +210,25 @@ func findInList(bucket *list.List, value bitmap) *list.Element { return nil } -func appendNodes(nodes []*SortedNode, start, end *list.Element, target bitmap) []*SortedNode { - for curr := start; curr != end; curr = curr.Next() { +func countInList(bucket *list.List) int { + count := 0 + for curr := bucket.Front(); curr != nil; curr = curr.Next() { + count++ + } + return count +} + +func appendNodes(nodes []*SortedNode, start *list.Element, target bitmap) []*SortedNode { + for curr := start; curr != nil; curr = curr.Next() { node := curr.Value.(*Node) nodes = append(nodes, &SortedNode{node, node.id.Xor(target)}) } return nodes } + +func bucketFor(id bitmap, target bitmap) int { + if id.Equals(target) { + panic("nodes do not have a bucket for themselves") + } + return numBuckets - 1 - target.Xor(id).PrefixLen() +} diff --git a/dht/rpc.go b/dht/rpc.go new file mode 100644 index 0000000..35db813 --- /dev/null +++ b/dht/rpc.go @@ -0,0 +1,210 @@ +package dht + +import ( + "crypto/rand" + "encoding/hex" + "net" + "reflect" + "strings" + "time" + + "github.com/davecgh/go-spew/spew" + "github.com/lyoshenka/bencode" + log "github.com/sirupsen/logrus" + "github.com/spf13/cast" +) + +func newMessageID() string { + buf := make([]byte, messageIDLength) + _, err := rand.Read(buf) + if err != nil { + panic(err) + } + return string(buf) +} + +// handlePacke handles packets received from udp. +func handlePacket(dht *DHT, pkt packet) { + //log.Infof("Received message from %s:%s : %s\n", pkt.raddr.IP.String(), strconv.Itoa(pkt.raddr.Port), hex.EncodeToString(pkt.data)) + + var data map[string]interface{} + err := bencode.DecodeBytes(pkt.data, &data) + if err != nil { + log.Errorf("error decoding data: %s\n%s", err, pkt.data) + return + } + + msgType, ok := data[headerTypeField] + if !ok { + log.Errorf("decoded data has no message type: %s", data) + return + } + + switch msgType.(int64) { + case requestType: + request := Request{} + err = bencode.DecodeBytes(pkt.data, &request) + if err != nil { + log.Errorln(err) + return + } + log.Debugf("[%s] query %s: received request from %s: %s(%s)", dht.node.id.Hex()[:8], hex.EncodeToString([]byte(request.ID))[:8], hex.EncodeToString([]byte(request.NodeID))[:8], request.Method, argsToString(request.Args)) + handleRequest(dht, pkt.raddr, request) + + case responseType: + response := Response{} + err = bencode.DecodeBytes(pkt.data, &response) + if err != nil { + return + } + log.Debugf("[%s] query %s: received response from %s: %s", dht.node.id.Hex()[:8], hex.EncodeToString([]byte(response.ID))[:8], hex.EncodeToString([]byte(response.NodeID))[:8], response.Data) + handleResponse(dht, pkt.raddr, response) + + case errorType: + e := Error{ + ID: data[headerMessageIDField].(string), + NodeID: data[headerNodeIDField].(string), + ExceptionType: data[headerPayloadField].(string), + Response: getArgs(data[headerArgsField]), + } + log.Debugf("[%s] query %s: received error from %s: %s", dht.node.id.Hex()[:8], hex.EncodeToString([]byte(e.ID))[:8], hex.EncodeToString([]byte(e.NodeID))[:8], e.ExceptionType) + handleError(dht, pkt.raddr, e) + + default: + log.Errorf("Invalid message type: %s", msgType) + return + } +} + +// handleRequest handles the requests received from udp. +func handleRequest(dht *DHT, addr *net.UDPAddr, request Request) { + if request.NodeID == dht.node.id.RawString() { + log.Warn("ignoring self-request") + return + } + + switch request.Method { + case pingMethod: + send(dht, addr, Response{ID: request.ID, NodeID: dht.node.id.RawString(), Data: pingSuccessResponse}) + case storeMethod: + if request.StoreArgs.BlobHash == "" { + log.Errorln("blobhash is empty") + return // nothing to store + } + // TODO: we should be sending the IP in the request, not just using the sender's IP + // TODO: should we be using StoreArgs.NodeID or StoreArgs.Value.LbryID ??? + dht.store.Upsert(request.StoreArgs.BlobHash, Node{id: request.StoreArgs.NodeID, ip: addr.IP, port: request.StoreArgs.Value.Port}) + send(dht, addr, Response{ID: request.ID, NodeID: dht.node.id.RawString(), Data: storeSuccessResponse}) + case findNodeMethod: + if len(request.Args) < 1 { + log.Errorln("nothing to find") + return + } + if len(request.Args[0]) != nodeIDLength { + log.Errorln("invalid node id") + return + } + doFindNodes(dht, addr, request) + case findValueMethod: + if len(request.Args) < 1 { + log.Errorln("nothing to find") + return + } + if len(request.Args[0]) != nodeIDLength { + log.Errorln("invalid node id") + return + } + + if nodes := dht.store.Get(request.Args[0]); len(nodes) > 0 { + response := Response{ID: request.ID, NodeID: dht.node.id.RawString()} + response.FindValueKey = request.Args[0] + response.FindNodeData = nodes + send(dht, addr, response) + } else { + doFindNodes(dht, addr, request) + } + + default: + // send(dht, addr, makeError(t, protocolError, "invalid q")) + log.Errorln("invalid request method") + return + } + + node := &Node{id: newBitmapFromString(request.NodeID), ip: addr.IP, port: addr.Port} + dht.rt.Update(node) +} + +func doFindNodes(dht *DHT, addr *net.UDPAddr, request Request) { + nodeID := newBitmapFromString(request.Args[0]) + closestNodes := dht.rt.FindClosest(nodeID, bucketSize) + if len(closestNodes) > 0 { + response := Response{ID: request.ID, NodeID: dht.node.id.RawString(), FindNodeData: make([]Node, len(closestNodes))} + for i, n := range closestNodes { + response.FindNodeData[i] = *n + } + send(dht, addr, response) + } else { + log.Warn("no nodes in routing table") + } +} + +// handleResponse handles responses received from udp. +func handleResponse(dht *DHT, addr *net.UDPAddr, response Response) { + tx := dht.tm.Find(response.ID, addr) + if tx != nil { + tx.res <- &response + } + + node := &Node{id: newBitmapFromString(response.NodeID), ip: addr.IP, port: addr.Port} + dht.rt.Update(node) +} + +// handleError handles errors received from udp. +func handleError(dht *DHT, addr *net.UDPAddr, e Error) { + spew.Dump(e) + node := &Node{id: newBitmapFromString(e.NodeID), ip: addr.IP, port: addr.Port} + dht.rt.Update(node) +} + +// send sends data to the udp. +func send(dht *DHT, addr *net.UDPAddr, data Message) error { + if req, ok := data.(Request); ok { + log.Debugf("[%s] query %s: sending request to %s : %s(%s)", dht.node.id.Hex()[:8], hex.EncodeToString([]byte(req.ID))[:8], addr.String(), req.Method, argsToString(req.Args)) + } else if res, ok := data.(Response); ok { + log.Debugf("[%s] query %s: sending response to %s : %s", dht.node.id.Hex()[:8], hex.EncodeToString([]byte(res.ID))[:8], addr.String(), res.ArgsDebug()) + } else { + log.Debugf("[%s] %s", dht.node.id.Hex()[:8], spew.Sdump(data)) + } + encoded, err := bencode.EncodeBytes(data) + if err != nil { + return err + } + //log.Infof("Encoded: %s", string(encoded)) + + dht.conn.SetWriteDeadline(time.Now().Add(time.Second * 15)) + + _, err = dht.conn.WriteToUDP(encoded, addr) + return err +} + +func getArgs(argsInt interface{}) []string { + var args []string + if reflect.TypeOf(argsInt).Kind() == reflect.Slice { + v := reflect.ValueOf(argsInt) + for i := 0; i < v.Len(); i++ { + args = append(args, cast.ToString(v.Index(i).Interface())) + } + } + return args +} + +func argsToString(args []string) string { + argsCopy := make([]string, len(args)) + copy(argsCopy, args) + for k, v := range argsCopy { + if len(v) == nodeIDLength { + argsCopy[k] = hex.EncodeToString([]byte(v))[:8] + } + } + return strings.Join(argsCopy, ", ") +} diff --git a/dht/dht_test.go b/dht/rpc_test.go similarity index 85% rename from dht/dht_test.go rename to dht/rpc_test.go index bda9e34..fc511db 100644 --- a/dht/dht_test.go +++ b/dht/rpc_test.go @@ -2,13 +2,61 @@ package dht import ( "net" + "strconv" + "strings" "testing" "time" + "github.com/lyoshenka/bencode" log "github.com/sirupsen/logrus" - "github.com/zeebo/bencode" ) +type testUDPPacket struct { + data []byte + addr *net.UDPAddr +} + +type testUDPConn struct { + addr *net.UDPAddr + toRead chan testUDPPacket + writes chan testUDPPacket +} + +func newTestUDPConn(addr string) *testUDPConn { + parts := strings.Split(addr, ":") + if len(parts) != 2 { + panic("addr needs ip and port") + } + port, err := strconv.Atoi(parts[1]) + if err != nil { + panic(err) + } + return &testUDPConn{ + addr: &net.UDPAddr{IP: net.IP(parts[0]), Port: port}, + toRead: make(chan testUDPPacket), + writes: make(chan testUDPPacket), + } +} + +func (t testUDPConn) ReadFromUDP(b []byte) (int, *net.UDPAddr, error) { + select { + case packet := <-t.toRead: + n := copy(b, packet.data) + return n, packet.addr, nil + //default: + // return 0, nil, nil + } +} + +func (t testUDPConn) WriteToUDP(b []byte, addr *net.UDPAddr) (int, error) { + t.writes <- testUDPPacket{data: b, addr: addr} + return len(b), nil +} + +func (t testUDPConn) SetWriteDeadline(tm time.Time) error { + return nil +} + func TestPing(t *testing.T) { log.SetLevel(log.DebugLevel) dhtNodeID := newRandomBitmap() @@ -16,12 +64,15 @@ func TestPing(t *testing.T) { conn := newTestUDPConn("127.0.0.1:21217") - dht := New(&Config{Address: "127.0.0.1:21216", NodeID: dhtNodeID.Hex()}) + dht, err := New(&Config{Address: "127.0.0.1:21216", NodeID: dhtNodeID.Hex()}) + if err != nil { + t.Fatal(err) + } dht.conn = conn dht.listen() go dht.runHandler() - messageID := newRandomBitmap().RawString() + messageID := newMessageID() data, err := bencode.EncodeBytes(map[string]interface{}{ headerTypeField: requestType, @@ -107,12 +158,16 @@ func TestStore(t *testing.T) { conn := newTestUDPConn("127.0.0.1:21217") - dht := New(&Config{Address: "127.0.0.1:21216", NodeID: dhtNodeID.Hex()}) + dht, err := New(&Config{Address: "127.0.0.1:21216", NodeID: dhtNodeID.Hex()}) + if err != nil { + t.Fatal(err) + } + dht.conn = conn dht.listen() go dht.runHandler() - messageID := newRandomBitmap().RawString() + messageID := newMessageID() blobHashToStore := newRandomBitmap().RawString() storeRequest := Request{ @@ -178,7 +233,7 @@ func TestStore(t *testing.T) { } } - if len(dht.store.data) != 1 { + if len(dht.store.nodeIDs) != 1 { t.Error("dht store has wrong number of items") } @@ -197,7 +252,10 @@ func TestFindNode(t *testing.T) { conn := newTestUDPConn("127.0.0.1:21217") - dht := New(&Config{Address: "127.0.0.1:21216", NodeID: dhtNodeID.Hex()}) + dht, err := New(&Config{Address: "127.0.0.1:21216", NodeID: dhtNodeID.Hex()}) + if err != nil { + t.Fatal(err) + } dht.conn = conn dht.listen() go dht.runHandler() @@ -207,10 +265,10 @@ func TestFindNode(t *testing.T) { for i := 0; i < nodesToInsert; i++ { n := Node{id: newRandomBitmap(), ip: net.ParseIP("127.0.0.1"), port: 10000 + i} nodes = append(nodes, n) - dht.routingTable.Update(&n) + dht.rt.Update(&n) } - messageID := newRandomBitmap().RawString() + messageID := newMessageID() blobHashToFind := newRandomBitmap().RawString() request := Request{ @@ -270,7 +328,11 @@ func TestFindValueExisting(t *testing.T) { conn := newTestUDPConn("127.0.0.1:21217") - dht := New(&Config{Address: "127.0.0.1:21216", NodeID: dhtNodeID.Hex()}) + dht, err := New(&Config{Address: "127.0.0.1:21216", NodeID: dhtNodeID.Hex()}) + if err != nil { + t.Fatal(err) + } + dht.conn = conn dht.listen() go dht.runHandler() @@ -280,16 +342,18 @@ func TestFindValueExisting(t *testing.T) { for i := 0; i < nodesToInsert; i++ { n := Node{id: newRandomBitmap(), ip: net.ParseIP("127.0.0.1"), port: 10000 + i} nodes = append(nodes, n) - dht.routingTable.Update(&n) + dht.rt.Update(&n) } //data, _ := hex.DecodeString("64313a30693065313a3132303a7de8e57d34e316abbb5a8a8da50dcd1ad4c80e0f313a3234383a7ce1b831dec8689e44f80f547d2dea171f6a625e1a4ff6c6165e645f953103dabeb068a622203f859c6c64658fd3aa3b313a33393a66696e6456616c7565313a346c34383aa47624b8e7ee1e54df0c45e2eb858feb0b705bd2a78d8b739be31ba188f4bd6f56b371c51fecc5280d5fd26ba4168e966565") - messageID := newRandomBitmap().RawString() + messageID := newMessageID() valueToFind := newRandomBitmap().RawString() nodeToFind := Node{id: newRandomBitmap(), ip: net.ParseIP("1.2.3.4"), port: 1286} - dht.store.Insert(valueToFind, nodeToFind) + dht.store.Upsert(valueToFind, nodeToFind) + dht.store.Upsert(valueToFind, nodeToFind) + dht.store.Upsert(valueToFind, nodeToFind) request := Request{ ID: messageID, @@ -348,7 +412,11 @@ func TestFindValueFallbackToFindNode(t *testing.T) { conn := newTestUDPConn("127.0.0.1:21217") - dht := New(&Config{Address: "127.0.0.1:21216", NodeID: dhtNodeID.Hex()}) + dht, err := New(&Config{Address: "127.0.0.1:21216", NodeID: dhtNodeID.Hex()}) + if err != nil { + t.Fatal(err) + } + dht.conn = conn dht.listen() go dht.runHandler() @@ -358,10 +426,10 @@ func TestFindValueFallbackToFindNode(t *testing.T) { for i := 0; i < nodesToInsert; i++ { n := Node{id: newRandomBitmap(), ip: net.ParseIP("127.0.0.1"), port: 10000 + i} nodes = append(nodes, n) - dht.routingTable.Update(&n) + dht.rt.Update(&n) } - messageID := newRandomBitmap().RawString() + messageID := newMessageID() valueToFind := newRandomBitmap().RawString() request := Request{ @@ -442,6 +510,9 @@ func verifyResponse(t *testing.T, resp map[string]interface{}, messageID, dhtNod } else if rMessageID != messageID { t.Error("unexpected message ID") } + if len(rMessageID) != messageIDLength { + t.Errorf("message ID should be %d chars long", messageIDLength) + } } _, ok = resp[headerNodeIDField] @@ -454,6 +525,9 @@ func verifyResponse(t *testing.T, resp map[string]interface{}, messageID, dhtNod } else if rNodeID != dhtNodeID { t.Error("unexpected node ID") } + if len(rNodeID) != nodeIDLength { + t.Errorf("node ID should be %d chars long", nodeIDLength) + } } } diff --git a/dht/store.go b/dht/store.go index 2178e60..8a34bab 100644 --- a/dht/store.go +++ b/dht/store.go @@ -4,39 +4,52 @@ import "sync" type peer struct { node Node + //, + // + // } type peerStore struct { - data map[string][]peer - lock sync.RWMutex + nodeIDs map[string]map[bitmap]bool + nodeInfo map[bitmap]peer + lock sync.RWMutex } func newPeerStore() *peerStore { return &peerStore{ - data: make(map[string][]peer), + nodeIDs: make(map[string]map[bitmap]bool), + nodeInfo: make(map[bitmap]peer), } } -func (s *peerStore) Insert(key string, node Node) { +func (s *peerStore) Upsert(key string, node Node) { s.lock.Lock() defer s.lock.Unlock() - newPeer := peer{node: node} - _, ok := s.data[key] - if !ok { - s.data[key] = []peer{newPeer} - } else { - s.data[key] = append(s.data[key], newPeer) + if _, ok := s.nodeIDs[key]; !ok { + s.nodeIDs[key] = make(map[bitmap]bool) } + s.nodeIDs[key][node.id] = true + s.nodeInfo[node.id] = peer{node: node} } func (s *peerStore) Get(key string) []Node { s.lock.RLock() defer s.lock.RUnlock() var nodes []Node - if peers, ok := s.data[key]; ok { - for _, p := range peers { - nodes = append(nodes, p.node) + if ids, ok := s.nodeIDs[key]; ok { + for id := range ids { + peer, ok := s.nodeInfo[id] + if !ok { + panic("node id in IDs list, but not in nodeInfo") + } + nodes = append(nodes, peer.node) } } return nodes } + +func (s *peerStore) CountKnownNodes() int { + s.lock.RLock() + defer s.lock.RUnlock() + return len(s.nodeInfo) +} diff --git a/dht/transaction_manager.go b/dht/transaction_manager.go new file mode 100644 index 0000000..3ed9444 --- /dev/null +++ b/dht/transaction_manager.go @@ -0,0 +1,101 @@ +package dht + +import ( + "net" + "sync" + "time" + + log "github.com/sirupsen/logrus" +) + +// query represents the query data included queried node and query-formed data. +type transaction struct { + node *Node + req *Request + res chan *Response +} + +// transactionManager represents the manager of transactions. +type transactionManager struct { + lock *sync.RWMutex + transactions map[string]*transaction + dht *DHT +} + +// newTransactionManager returns new transactionManager pointer. +func newTransactionManager(dht *DHT) *transactionManager { + return &transactionManager{ + lock: &sync.RWMutex{}, + transactions: make(map[string]*transaction), + dht: dht, + } +} + +// insert adds a transaction to transactionManager. +func (tm *transactionManager) insert(trans *transaction) { + tm.lock.Lock() + defer tm.lock.Unlock() + tm.transactions[trans.req.ID] = trans +} + +// delete removes a transaction from transactionManager. +func (tm *transactionManager) delete(transID string) { + tm.lock.Lock() + defer tm.lock.Unlock() + delete(tm.transactions, transID) +} + +// find transaction for id. optionally ensure that addr matches node from transaction +func (tm *transactionManager) Find(id string, addr *net.UDPAddr) *transaction { + tm.lock.RLock() + defer tm.lock.RUnlock() + + t, ok := tm.transactions[id] + if !ok { + return nil + } else if addr != nil && t.node.Addr().String() != addr.String() { + return nil + } + + return t +} + +func (tm *transactionManager) Send(node *Node, req *Request) *Response { + if node.id.Equals(tm.dht.node.id) { + log.Error("sending query to self") + return nil + } + + req.ID = newMessageID() + trans := &transaction{ + node: node, + req: req, + res: make(chan *Response), + } + + tm.insert(trans) + defer tm.delete(trans.req.ID) + + for i := 0; i < udpRetry; i++ { + if err := send(tm.dht, trans.node.Addr(), *trans.req); err != nil { + log.Error(err) + break + } + + select { + case res := <-trans.res: + return res + case <-time.After(udpTimeout): + } + } + + tm.dht.rt.RemoveByID(trans.node.id) + return nil +} + +// Count returns the number of transactions in the manager +func (tm *transactionManager) Count() int { + tm.lock.Lock() + defer tm.lock.Unlock() + return len(tm.transactions) +}