diff --git a/dht/bitmap.go b/dht/bitmap.go new file mode 100644 index 0000000..74224b8 --- /dev/null +++ b/dht/bitmap.go @@ -0,0 +1,85 @@ +package dht + +import ( + "encoding/hex" + "math/rand" + "strconv" +) + +type bitmap [nodeIDLength]byte + +func (b bitmap) RawString() string { + return string(b[0:nodeIDLength]) +} + +func (b bitmap) Hex() string { + return hex.EncodeToString(b[0:nodeIDLength]) +} + +func (b bitmap) Equals(other bitmap) bool { + for k := range b { + if b[k] != other[k] { + return false + } + } + return true +} + +func (b bitmap) Less(other interface{}) bool { + for k := range b { + if b[k] != other.(bitmap)[k] { + return b[k] < other.(bitmap)[k] + } + } + return false +} + +func (b bitmap) Xor(other bitmap) bitmap { + var ret bitmap + for k := range b { + ret[k] = b[k] ^ other[k] + } + return ret +} + +// PrefixLen returns the number of leading 0 bits +func (b bitmap) PrefixLen() (ret int) { + for i := range b { + for j := 0; j < 8; j++ { + if (b[i]>>uint8(7-j))&0x1 != 0 { + return i*8 + j + } + } + } + return nodeIDLength*8 - 1 +} + +func newBitmapFromBytes(data []byte) bitmap { + if len(data) != nodeIDLength { + panic("invalid bitmap of length " + strconv.Itoa(len(data))) + } + + var bmp bitmap + copy(bmp[:], data) + return bmp +} + +func newBitmapFromString(data string) bitmap { + return newBitmapFromBytes([]byte(data)) +} + +func newBitmapFromHex(hexStr string) bitmap { + decoded, err := hex.DecodeString(hexStr) + if err != nil { + panic(err) + } + return newBitmapFromBytes(decoded) +} + +func newRandomBitmap() bitmap { + var id bitmap + for k := range id { + id[k] = uint8(rand.Intn(256)) + } + return id +} diff --git a/dht/bitmap_test.go b/dht/bitmap_test.go new file mode 100644 index 0000000..325f573 --- /dev/null +++ b/dht/bitmap_test.go @@ -0,0 +1,48 @@ +package dht + +import "testing" + +func TestBitmap(t *testing.T) { + a := bitmap{ + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, + 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, + 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, + 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, + } + b := bitmap{ + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, + 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, + 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, + 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 47, 46, + } + c := bitmap{ + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, + } + + if !a.Equals(a) { + t.Error("bitmap does not equal itself") + } + if a.Equals(b) { + t.Error("bitmap equals another bitmap with different id") + } + + if !a.Xor(b).Equals(c) { + t.Error(a.Xor(b)) + } + + if c.PrefixLen() != 375 { + t.Error(c.PrefixLen()) + } + + if b.Less(a) { + t.Error("bitmap fails lessThan test") + } + + id := "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef" + if newBitmapFromHex(id).Hex() != id { + t.Error(newBitmapFromHex(id).Hex()) + } +} diff --git a/dht/conn.go b/dht/conn.go new file mode 100644 index 0000000..3704a70 --- /dev/null +++ b/dht/conn.go @@ -0,0 +1,60 @@ +package dht + +import ( + "net" + "strconv" + "strings" + "time" +) + +type UDPConn interface { + ReadFromUDP([]byte) (int, *net.UDPAddr, error) + WriteToUDP([]byte, *net.UDPAddr) (int, error) + SetWriteDeadline(time.Time) error +} + +type testUDPPacket struct { + data []byte + addr *net.UDPAddr +} + +type testUDPConn struct { + addr *net.UDPAddr + toRead chan testUDPPacket + writes chan testUDPPacket +} + +func newTestUDPConn(addr string) *testUDPConn { + parts := strings.Split(addr, ":") + if len(parts) != 2 { + panic("addr needs ip and port") + } + port, err := strconv.Atoi(parts[1]) + if err != nil { + panic(err) + } + return &testUDPConn{ + addr: &net.UDPAddr{IP: net.IP(parts[0]), Port: port}, + toRead: make(chan testUDPPacket), + writes: make(chan testUDPPacket), + } +} + +func (t testUDPConn) ReadFromUDP(b []byte) (int, *net.UDPAddr, error) { + select { + case packet := <-t.toRead: + n := copy(b, packet.data) + return n, packet.addr, nil + //default: + // return 0, nil, nil + } +} + +func (t testUDPConn) WriteToUDP(b []byte, addr *net.UDPAddr) (int, error) { + t.writes <- testUDPPacket{data: b, addr: addr} + return len(b), nil +} + +func (t testUDPConn) SetWriteDeadline(tm time.Time) error { + return nil +} diff --git a/dht/dht.go b/dht/dht.go new file mode 100644 index 0000000..0cdedda --- /dev/null +++ b/dht/dht.go @@ -0,0 +1,350 @@ +package dht + +import ( + "encoding/hex" + "net" + "reflect" + "strings" + "time" + + "github.com/davecgh/go-spew/spew" + log "github.com/sirupsen/logrus" + "github.com/spf13/cast" + "github.com/zeebo/bencode" +) + +const network = "udp4" +const bucketSize = 20 +const numBuckets = nodeIDLength * 8 + +// packet represents the information receive from udp. +type packet struct { + data []byte + raddr *net.UDPAddr +} + +// Config represents the configure of dht. +type Config struct { + // this node's address. format is `ip:port` + Address string + // the seed nodes through which we can join in dht network + SeedNodes []string + // the hex-encoded node id for this node. if string is empty, a random id will be generated + NodeID string +} + +// NewStandardConfig returns a Config pointer with default values. +func NewStandardConfig() *Config { + return &Config{ + Address: ":4444", + SeedNodes: []string{ + "lbrynet1.lbry.io:4444", + "lbrynet2.lbry.io:4444", + "lbrynet3.lbry.io:4444", + }, + } +} + +// DHT represents a DHT node. +type DHT struct { + conf *Config + conn UDPConn + node *Node + routingTable *RoutingTable + packets chan packet +} + +// New returns a DHT pointer. If config is nil, then config will be set to the default config. +func New(config *Config) *DHT { + if config == nil { + config = NewStandardConfig() + } + + var id bitmap + if config.NodeID == "" { + id = newRandomBitmap() + } else { + id = newBitmapFromHex(config.NodeID) + } + node := &Node{id: id, addr: config.Address} + return &DHT{ + conf: config, + node: node, + routingTable: NewRoutingTable(node), + packets: make(chan packet), + } +} + +// init initializes global variables. +func (dht *DHT) init() { + log.Info("Initializing DHT on " + dht.conf.Address) + log.Infof("Node ID is %s", dht.node.id.Hex()) + listener, err := net.ListenPacket(network, dht.conf.Address) + if err != nil { + panic(err) + } + + dht.conn = listener.(*net.UDPConn) +} + +// listen receives message from udp. +func (dht *DHT) listen() { + go func() { + buf := make([]byte, 8192) + for { + n, raddr, err := dht.conn.ReadFromUDP(buf) + if err != nil { + log.Errorf("udp read error: %v", err) + continue + } else if raddr == nil { + log.Errorf("udp read with no raddr") + continue + } + dht.packets <- packet{data: buf[:n], raddr: raddr} + } + }() +} + +// join makes current node join the dht network. +func (dht *DHT) join() { + for _, addr := range dht.conf.SeedNodes { + raddr, err := net.ResolveUDPAddr(network, addr) + if err != nil { + continue + } + + _ = raddr + + // NOTE: Temporary node has NO node id. + //dht.transactionManager.findNode( + // &node{addr: raddr}, + // dht.node.id.RawString(), + //) + } +} + +func (dht *DHT) runHandler() { + var pkt packet + + for { + select { + case pkt = <-dht.packets: + handle(dht, pkt) + } + } +} + +// Run starts the dht. +func (dht *DHT) Run() { + dht.init() + dht.listen() + dht.join() + log.Info("DHT ready") + dht.runHandler() +} + +// handle handles packets received from udp. +func handle(dht *DHT, pkt packet) { + //log.Infof("Received message from %s:%s : %s\n", pkt.raddr.IP.String(), strconv.Itoa(pkt.raddr.Port), hex.EncodeToString(pkt.data)) + + var data map[string]interface{} + err := bencode.DecodeBytes(pkt.data, &data) + if err != nil { + log.Errorf("Error decoding data: %s\n%s", err, pkt.data) + return + } + + msgType, ok := data[headerTypeField] + if !ok { + log.Errorf("Decoded data has no message type: %s", data) + return + } + + switch msgType.(int64) { + case requestType: + request := Request{ + ID: data[headerMessageIDField].(string), + NodeID: data[headerNodeIDField].(string), + Method: data[headerPayloadField].(string), + Args: getArgs(data[headerArgsField]), + } + log.Infof("%s: Received from %s: %s(%s)", dht.node.id.Hex()[:8], hex.EncodeToString([]byte(request.NodeID))[:8], request.Method, argsToString(request.Args)) + handleRequest(dht, pkt.raddr, request) + + case responseType: + response := Response{ + ID: data[headerMessageIDField].(string), + NodeID: data[headerNodeIDField].(string), + } + + if reflect.TypeOf(data[headerPayloadField]).Kind() == reflect.String { + response.Data = data[headerPayloadField].(string) + } else { + response.FindNodeData = getFindNodeResponse(data[headerPayloadField]) + } + + handleResponse(dht, pkt.raddr, response) + + case errorType: + e := Error{ + ID: data[headerMessageIDField].(string), + NodeID: data[headerNodeIDField].(string), + ExceptionType: data[headerPayloadField].(string), + Response: getArgs(data[headerArgsField]), + } + handleError(dht, pkt.raddr, e) + + default: + log.Errorf("Invalid message type: %s", msgType) + return + } +} + +// handleRequest handles the requests received from udp. +func handleRequest(dht *DHT, addr *net.UDPAddr, request Request) (success bool) { + log.Infoln("handling request") + if request.NodeID == dht.node.id.RawString() { + log.Warn("ignoring self-request") + return + } + + switch request.Method { + case pingMethod: + log.Println("ping") + send(dht, addr, Response{ID: request.ID, NodeID: dht.node.id.RawString(), Data: "pong"}) + case storeMethod: + log.Println("store") + case findNodeMethod: + log.Println("findnode") + //if len(request.Args) < 1 { + // send(dht, addr, Error{ID: request.ID, NodeID: dht.node.id.RawString(), Response: []string{"No target"}}) + // return + //} + // + //target := request.Args[0] + //if len(target) != nodeIDLength { + // send(dht, addr, Error{ID: request.ID, NodeID: dht.node.id.RawString(), Response: []string{"Invalid target"}}) + // return + //} + // + //nodes := []findNodeDatum{} + //targetID := newBitmapFromString(target) + // + //no, _ := dht.routingTable.GetNodeKBucktByID(targetID) + //if no != nil { + // nodes = []findNodeDatum{{ID: no.id.RawString(), IP: no.addr.IP.String(), Port: no.addr.Port}} + //} else { + // neighbors := dht.routingTable.GetNeighbors(targetID, dht.K) + // for _, n := range neighbors { + // nodes = append(nodes, findNodeDatum{ID: n.id.RawString(), IP: n.addr.IP.String(), Port: n.addr.Port}) + // } + //} + // + //send(dht, addr, Response{ID: request.ID, NodeID: dht.node.id.RawString(), FindNodeData: nodes}) + + default: + // send(dht, addr, makeError(t, protocolError, "invalid q")) + return + } + + node := &Node{id: newBitmapFromString(request.NodeID), addr: addr.String()} + dht.routingTable.Update(node) + return true +} + +// handleResponse handles responses received from udp. +func handleResponse(dht *DHT, addr *net.UDPAddr, response Response) (success bool) { + spew.Dump(response) + + //switch trans.request.Method { + //case pingMethod: + //case findNodeMethod: + // target := trans.request.Args[0] + // if findOn(dht, response.FindNodeData, newBitmapFromString(target), findNodeMethod) != nil { + // return + // } + //default: + // return + //} + + node := &Node{id: newBitmapFromString(response.NodeID), addr: addr.String()} + dht.routingTable.Update(node) + + return true +} + +// handleError handles errors received from udp. +func handleError(dht *DHT, addr *net.UDPAddr, e Error) (success bool) { + spew.Dump(e) + return true +} + +// send sends data to the udp. +func send(dht *DHT, addr *net.UDPAddr, data Message) error { + if req, ok := data.(Request); ok { + log.Infof("%s: Sending %s(%s)", hex.EncodeToString([]byte(req.NodeID))[:8], req.Method, argsToString(req.Args)) + } else { + log.Infof("%s: Sending %s", data.GetID(), spew.Sdump(data)) + } + encoded, err := data.Encode() + if err != nil { + return err + } + //log.Infof("Encoded: %s", string(encoded)) + + dht.conn.SetWriteDeadline(time.Now().Add(time.Second * 15)) + + _, err = dht.conn.WriteToUDP(encoded, addr) + return err +} + +func getFindNodeResponse(i interface{}) (data []findNodeDatum) { + if reflect.TypeOf(i).Kind() != reflect.Slice { + return + } + + v := reflect.ValueOf(i) + for i := 0; i < v.Len(); i++ { + if v.Index(i).Kind() != reflect.Interface { + continue + } + + contact := v.Index(i).Elem() + if contact.Type().Kind() != reflect.Slice || contact.Len() != 3 { + continue + } + + if contact.Index(0).Elem().Kind() != reflect.String || + contact.Index(1).Elem().Kind() != reflect.String || + !(contact.Index(2).Elem().Kind() == reflect.Int64 || + contact.Index(2).Elem().Kind() == reflect.Int) { + continue + } + + data = append(data, findNodeDatum{ + ID: contact.Index(0).Elem().String(), + IP: contact.Index(1).Elem().String(), + Port: int(contact.Index(2).Elem().Int()), + }) + } + return +} + +func getArgs(argsInt interface{}) (args []string) { + if reflect.TypeOf(argsInt).Kind() == reflect.Slice { + v := reflect.ValueOf(argsInt) + for i := 0; i < v.Len(); i++ { + args = append(args, cast.ToString(v.Index(i).Interface())) + } + } + return +} + +func argsToString(args []string) string { + for k, v := range args { + if len(v) == nodeIDLength { + args[k] = hex.EncodeToString([]byte(v))[:8] + } + } + return strings.Join(args, ", ") +} diff --git a/dht/dht_test.go b/dht/dht_test.go new file mode 100644 index 0000000..8d5d9a8 --- /dev/null +++ b/dht/dht_test.go @@ -0,0 +1,193 @@ +package dht + +import ( + "testing" + "time" + + "github.com/zeebo/bencode" +) + +func TestPing(t *testing.T) { + dhtNodeID := newRandomBitmap() + testNodeID := newRandomBitmap() + + conn := newTestUDPConn("127.0.0.1:21217") + + dht := New(&Config{Address: ":21216", NodeID: dhtNodeID.Hex()}) + dht.conn = conn + dht.listen() + go dht.runHandler() + + messageID := newRandomBitmap().RawString() + + data, err := bencode.EncodeBytes(map[string]interface{}{ + headerTypeField: requestType, + headerMessageIDField: messageID, + headerNodeIDField: testNodeID.RawString(), + headerPayloadField: "ping", + headerArgsField: []string{}, + }) + if err != nil { + panic(err) + } + + conn.toRead <- testUDPPacket{addr: conn.addr, data: data} + timer := time.NewTimer(3 * time.Second) + + select { + case <-timer.C: + t.Error("timeout") + case resp := <-conn.writes: + var response map[string]interface{} + err := bencode.DecodeBytes(resp.data, &response) + if err != nil { + t.Error(err) + return + } + + if len(response) != 4 { + t.Errorf("expected 4 response fields, got %d", len(response)) + } + + _, ok := response[headerTypeField] + if !ok { + t.Error("missing type field") + } else { + rType, ok := response[headerTypeField].(int64) + if !ok { + t.Error("type is not an integer") + } else if rType != responseType { + t.Error("unexpected response type") + } + } + + _, ok = response[headerMessageIDField] + if !ok { + t.Error("missing message id field") + } else { + rMessageID, ok := response[headerMessageIDField].(string) + if !ok { + t.Error("message ID is not a string") + } else if rMessageID != messageID { + t.Error("unexpected message ID") + } + } + + _, ok = response[headerNodeIDField] + if !ok { + t.Error("missing node id field") + } else { + rNodeID, ok := response[headerNodeIDField].(string) + if !ok { + t.Error("node ID is not a string") + } else if rNodeID != dhtNodeID.RawString() { + t.Error("unexpected node ID") + } + } + + _, ok = response[headerPayloadField] + if !ok { + t.Error("missing payload field") + } else { + rNodeID, ok := response[headerPayloadField].(string) + if !ok { + t.Error("payload is not a string") + } else if rNodeID != pingSuccessResponse { + t.Error("did not pong") + } + } + } +} + +func TestStore(t *testing.T) { + dhtNodeID := newRandomBitmap() + testNodeID := newRandomBitmap() + + conn := newTestUDPConn("127.0.0.1:21217") + + dht := New(&Config{Address: ":21216", NodeID: dhtNodeID.Hex()}) + dht.conn = conn + dht.listen() + go dht.runHandler() + + messageID := newRandomBitmap().RawString() + idToStore := newRandomBitmap().RawString() + + data, err := bencode.EncodeBytes(map[string]interface{}{ + headerTypeField: requestType, + headerMessageIDField: messageID, + headerNodeIDField: testNodeID.RawString(), + headerPayloadField: "store", + headerArgsField: []string{idToStore}, + }) + if err != nil { + panic(err) + } + + conn.toRead <- testUDPPacket{addr: conn.addr, data: data} + timer := time.NewTimer(3 * time.Second) + + select { + case <-timer.C: + t.Error("timeout") + case resp := <-conn.writes: + var response map[string]interface{} + err := bencode.DecodeBytes(resp.data, &response) + if err != nil { + t.Error(err) + return + } + + if len(response) != 4 { + t.Errorf("expected 4 response fields, got %d", len(response)) + } + + _, ok := response[headerTypeField] + if !ok { + t.Error("missing type field") + } else { + rType, ok := response[headerTypeField].(int64) + if !ok { + t.Error("type is not an integer") + } else if rType != responseType { + t.Error("unexpected response type") + } + } + + _, ok = response[headerMessageIDField] + if !ok { + t.Error("missing message id field") + } else { + rMessageID, ok := response[headerMessageIDField].(string) + if !ok { + t.Error("message ID is not a string") + } else if rMessageID != messageID { + t.Error("unexpected message ID") + } + } + + _, ok = response[headerNodeIDField] + if !ok { + t.Error("missing node id field") + } else { + rNodeID, ok := response[headerNodeIDField].(string) + if !ok { + t.Error("node ID is not a string") + } else if rNodeID != dhtNodeID.RawString() { + t.Error("unexpected node ID") + } + } + + _, ok = response[headerPayloadField] + if !ok { + t.Error("missing payload field") + } else { + rNodeID, ok := response[headerPayloadField].(string) + if !ok { + t.Error("payload is not a string") + } else if rNodeID != storeSuccessResponse { + t.Error("did not return OK") + } + } + } +} diff --git a/dht/messages.go b/dht/messages.go new file mode 100644 index 0000000..615d8b0 --- /dev/null +++ b/dht/messages.go @@ -0,0 +1,103 @@ +package dht + +import "github.com/zeebo/bencode" + +const ( + pingMethod = "ping" + storeMethod = "store" + findNodeMethod = "findNode" + findValueMethod = "findValue" +) + +const ( + pingSuccessResponse = "pong" + storeSuccessResponse = "OK" +) + +const ( + requestType = 0 + responseType = 1 + errorType = 2 +) + +const ( + // these are strings because bencode requires bytestring keys + headerTypeField = "0" + headerMessageIDField = "1" + headerNodeIDField = "2" + headerPayloadField = "3" + headerArgsField = "4" +) + +type Message interface { + GetID() string + Encode() ([]byte, error) +} + +type Request struct { + ID string + NodeID string + Method string + Args []string +} + +func (r Request) GetID() string { return r.ID } +func (r Request) Encode() ([]byte, error) { + return bencode.EncodeBytes(map[string]interface{}{ + headerTypeField: requestType, + headerMessageIDField: r.ID, + headerNodeIDField: r.NodeID, + headerPayloadField: r.Method, + headerArgsField: r.Args, + }) +} + +type findNodeDatum struct { + ID string + IP string + Port int +} +type Response struct { + ID string + NodeID string + Data string + FindNodeData []findNodeDatum +} + +func (r Response) GetID() string { return r.ID } +func (r Response) Encode() ([]byte, error) { + data := map[string]interface{}{ + headerTypeField: responseType, + headerMessageIDField: r.ID, + headerNodeIDField: r.NodeID, + } + if r.Data != "" { + data[headerPayloadField] = r.Data + } else { + var nodes []interface{} + for _, n := range r.FindNodeData { + nodes = append(nodes, []interface{}{n.ID, n.IP, n.Port}) + } + data[headerPayloadField] = nodes + } + + return bencode.EncodeBytes(data) +} + +type Error struct { + ID string + NodeID string + Response []string + ExceptionType string +} + +func (e Error) GetID() string { return e.ID } +func (e Error) Encode() ([]byte, error) { + return bencode.EncodeBytes(map[string]interface{}{ + headerTypeField: errorType, + headerMessageIDField: e.ID, + headerNodeIDField: e.NodeID, + headerPayloadField: e.ExceptionType, + headerArgsField: e.Response, + }) +} diff --git a/dht/node.go b/dht/node.go new file mode 100644 index 0000000..9df728b --- /dev/null +++ b/dht/node.go @@ -0,0 +1,20 @@ +package dht + +const nodeIDLength = 48 // bytes +const compactNodeInfoLength = nodeIDLength + 6 + +type Node struct { + id bitmap + addr string +} + +type SortedNode struct { + node *Node + sortKey bitmap +} + +type byXorDistance []*SortedNode + +func (a byXorDistance) Len() int { return len(a) } +func (a byXorDistance) Swap(i, j int) { a[i], a[j] = a[j], a[i] } +func (a byXorDistance) Less(i, j int) bool { return a[i].sortKey.Less(a[j].sortKey) } diff --git a/dht/routing_table.go b/dht/routing_table.go new file mode 100644 index 0000000..6daa92b --- /dev/null +++ b/dht/routing_table.go @@ -0,0 +1,79 @@ +package dht + +import ( + "container/list" + "sort" +) + +type RoutingTable struct { + node Node + buckets [numBuckets]*list.List +} + +func NewRoutingTable(node *Node) *RoutingTable { + var rt RoutingTable + for i := range rt.buckets { + rt.buckets[i] = list.New() + } + rt.node = *node + return &rt +} + +func (rt *RoutingTable) Update(node *Node) { + prefixLength := node.id.Xor(rt.node.id).PrefixLen() + bucket := rt.buckets[prefixLength] + element := findInList(bucket, rt.node.id) + if element == nil { + if bucket.Len() <= bucketSize { + bucket.PushBack(node) + } + // TODO: Handle insertion when the list is full by evicting old elements if + // they don't respond to a ping. + } else { + bucket.MoveToBack(element) + } +} + +func (rt *RoutingTable) FindClosest(target bitmap, count int) []*Node { + toSort := []*SortedNode{} + + prefixLength := target.Xor(rt.node.id).PrefixLen() + bucket := rt.buckets[prefixLength] + appendNodes(bucket.Front(), nil, &toSort, target) + + for i := 1; (prefixLength-i >= 0 || prefixLength+i < nodeIDLength*8) && len(toSort) < count; i++ { + if prefixLength-i >= 0 { + bucket = rt.buckets[prefixLength-i] + appendNodes(bucket.Front(), nil, &toSort, target) + } + if prefixLength+i < nodeIDLength*8 { + bucket = rt.buckets[prefixLength+i] + appendNodes(bucket.Front(), nil, &toSort, target) + } + } + + sort.Sort(byXorDistance(toSort)) + + nodes := []*Node{} + for _, c := range toSort { + nodes = append(nodes, c.node) + } + + return nodes +} + +func findInList(bucket *list.List, value bitmap) *list.Element { + for curr := bucket.Front(); curr != nil; curr = curr.Next() { + if curr.Value.(*Node).id.Equals(value) { + return curr + } + } + return nil +} + +func appendNodes(start, end *list.Element, nodes *[]*SortedNode, target bitmap) { + for curr := start; curr != end; curr = curr.Next() { + node := curr.Value.(*Node) + *nodes = append(*nodes, &SortedNode{node, node.id.Xor(target)}) + } +} diff --git a/dht/routing_table_test.go b/dht/routing_table_test.go new file mode 100644 index 0000000..d9695bd --- /dev/null +++ b/dht/routing_table_test.go @@ -0,0 +1,33 @@ +package dht + +import "testing" + +func TestRoutingTable(t *testing.T) { + n1 := newBitmapFromHex("FFFFFFFF0000000000000000000000000000000000000000000000000000000000000000000000000000000000000000") + n2 := newBitmapFromHex("FFFFFFF00000000000000000000000000000000000000000000000000000000000000000000000000000000000000000") + n3 := newBitmapFromHex("111111110000000000000000000000000000000000000000000000000000000000000000000000000000000000000000") + rt := NewRoutingTable(&Node{n1, "localhost:8000"}) + rt.Update(&Node{n2, "localhost:8001"}) + rt.Update(&Node{n3, "localhost:8002"}) + + contacts := rt.FindClosest(newBitmapFromHex("222222220000000000000000000000000000000000000000000000000000000000000000000000000000000000000000"), 1) + if len(contacts) != 1 { + t.Fail() + return + } + if !contacts[0].id.Equals(n3) { + t.Error(contacts[0]) + } + + contacts = rt.FindClosest(n2, 10) + if len(contacts) != 2 { + t.Error(len(contacts)) + return + } + if !contacts[0].id.Equals(n2) { + t.Error(contacts[0]) + } + if !contacts[1].id.Equals(n3) { + t.Error(contacts[1]) + } +}