From 3bb2d90b7b273150911cd293f2209d130800c873 Mon Sep 17 00:00:00 2001 From: Alex Grintsvayg Date: Fri, 9 Mar 2018 16:43:30 -0500 Subject: [PATCH] findvalue done --- dht/dht.go | 83 ++++++------- dht/dht_test.go | 255 ++++++++++++++++++++++++++++++++------ dht/message.go | 12 +- dht/routing_table.go | 11 +- dht/routing_table_test.go | 15 +-- dht/store.go | 12 +- main.go | 23 ---- 7 files changed, 284 insertions(+), 127 deletions(-) diff --git a/dht/dht.go b/dht/dht.go index 0267b4f..bc57393 100644 --- a/dht/dht.go +++ b/dht/dht.go @@ -20,7 +20,7 @@ const nodeIDLength = 48 // bytes. this is the constant B in the spec const bucketSize = 20 // this is the constant k in the spec const tExpire = 86400 * time.Second // the time after which a key/value pair expires; this is a time-to-live (TTL) from the original publication date -const tRefresh = 3600 * time.Second // after which an otherwise unaccessed bucket must be refreshed +const tRefresh = 3600 * time.Second // the time after which an otherwise unaccessed bucket must be refreshed const tReplicate = 3600 * time.Second // the interval between Kademlia replication events, when a node is required to publish its entire database const tRepublish = 86400 * time.Second // the time after which the original publisher must republish a key/value pair @@ -46,7 +46,7 @@ type Config struct { // NewStandardConfig returns a Config pointer with default values. func NewStandardConfig() *Config { return &Config{ - Address: ":4444", + Address: "127.0.0.1:4444", SeedNodes: []string{ "lbrynet1.lbry.io:4444", "lbrynet2.lbry.io:4444", @@ -81,13 +81,21 @@ func New(config *Config) *DHT { ip, port, err := net.SplitHostPort(config.Address) if err != nil { panic(err) + } else if ip == "" { + panic("address does not contain an IP") + } else if port == "" { + panic("address does not contain a port") } + portInt, err := cast.ToIntE(port) if err != nil { panic(err) } - node := &Node{id: id, ip: ip, port: portInt} + node := &Node{id: id, ip: net.ParseIP(ip), port: portInt} + if node.ip == nil { + panic("invalid ip") + } return &DHT{ conf: config, node: node, @@ -219,7 +227,7 @@ func handle(dht *DHT, pkt packet) { } // handleRequest handles the requests received from udp. -func handleRequest(dht *DHT, addr *net.UDPAddr, request Request) (success bool) { +func handleRequest(dht *DHT, addr *net.UDPAddr, request Request) { if request.NodeID == dht.node.id.RawString() { log.Warn("ignoring self-request") return @@ -233,7 +241,9 @@ func handleRequest(dht *DHT, addr *net.UDPAddr, request Request) (success bool) log.Errorln("blobhash is empty") return // nothing to store } - dht.store.Insert(request.StoreArgs.BlobHash, request.StoreArgs.NodeID) + // TODO: we should be sending the IP in the request, not just using the sender's IP + // TODO: should we be using StoreArgs.NodeID or StoreArgs.Value.LbryID ??? + dht.store.Insert(request.StoreArgs.BlobHash, Node{id: request.StoreArgs.NodeID, ip: addr.IP, port: request.StoreArgs.Value.Port}) send(dht, addr, Response{ID: request.ID, NodeID: dht.node.id.RawString(), Data: storeSuccessResponse}) case findNodeMethod: log.Println("findnode") @@ -245,13 +255,7 @@ func handleRequest(dht *DHT, addr *net.UDPAddr, request Request) (success bool) log.Errorln("invalid node id") return } - nodeID := newBitmapFromString(request.Args[0]) - closestNodes := dht.routingTable.FindClosest(nodeID, bucketSize) - response := Response{ID: request.ID, NodeID: dht.node.id.RawString(), FindNodeData: make([]Node, len(closestNodes))} - for i, n := range closestNodes { - response.FindNodeData[i] = *n - } - send(dht, addr, response) + doFindNodes(dht, addr, request) case findValueMethod: log.Println("findvalue") if len(request.Args) < 1 { @@ -263,59 +267,52 @@ func handleRequest(dht *DHT, addr *net.UDPAddr, request Request) (success bool) return } - nodeIDs := dht.store.Get(request.Args[0]) - if len(nodeIDs) > 0 { - // return node ids + if nodes := dht.store.Get(request.Args[0]); len(nodes) > 0 { + response := Response{ID: request.ID, NodeID: dht.node.id.RawString()} + response.FindValueKey = request.Args[0] + response.FindNodeData = nodes + send(dht, addr, response) } else { - // switch to findNode + doFindNodes(dht, addr, request) } - nodeID := newBitmapFromString(request.Args[0]) - closestNodes := dht.routingTable.FindClosest(nodeID, bucketSize) - response := Response{ID: request.ID, NodeID: dht.node.id.RawString(), FindNodeData: make([]Node, len(closestNodes))} - for i, n := range closestNodes { - response.FindNodeData[i] = *n - } - send(dht, addr, response) - default: // send(dht, addr, makeError(t, protocolError, "invalid q")) log.Errorln("invalid request method") return } - node := &Node{id: newBitmapFromString(request.NodeID), ip: addr.IP.String(), port: addr.Port} + node := &Node{id: newBitmapFromString(request.NodeID), ip: addr.IP, port: addr.Port} dht.routingTable.Update(node) - return true +} + +func doFindNodes(dht *DHT, addr *net.UDPAddr, request Request) { + nodeID := newBitmapFromString(request.Args[0]) + closestNodes := dht.routingTable.FindClosest(nodeID, bucketSize) + if len(closestNodes) > 0 { + response := Response{ID: request.ID, NodeID: dht.node.id.RawString(), FindNodeData: make([]Node, len(closestNodes))} + for i, n := range closestNodes { + response.FindNodeData[i] = *n + } + send(dht, addr, response) + } } // handleResponse handles responses received from udp. -func handleResponse(dht *DHT, addr *net.UDPAddr, response Response) (success bool) { +func handleResponse(dht *DHT, addr *net.UDPAddr, response Response) { spew.Dump(response) - //switch trans.request.Method { - //case pingMethod: - //case findNodeMethod: - // target := trans.request.Args[0] - // if findOn(dht, response.FindNodeData, newBitmapFromString(target), findNodeMethod) != nil { - // return - // } - //default: - // return - //} + // TODO: find transaction by message id, pass along response - node := &Node{id: newBitmapFromString(response.NodeID), ip: addr.IP.String(), port: addr.Port} + node := &Node{id: newBitmapFromString(response.NodeID), ip: addr.IP, port: addr.Port} dht.routingTable.Update(node) - - return true } // handleError handles errors received from udp. -func handleError(dht *DHT, addr *net.UDPAddr, e Error) (success bool) { +func handleError(dht *DHT, addr *net.UDPAddr, e Error) { spew.Dump(e) - node := &Node{id: newBitmapFromString(e.NodeID), ip: addr.IP.String(), port: addr.Port} + node := &Node{id: newBitmapFromString(e.NodeID), ip: addr.IP, port: addr.Port} dht.routingTable.Update(node) - return true } // send sends data to the udp. diff --git a/dht/dht_test.go b/dht/dht_test.go index 1c0a334..bda9e34 100644 --- a/dht/dht_test.go +++ b/dht/dht_test.go @@ -1,11 +1,10 @@ package dht import ( - "encoding/hex" + "net" "testing" "time" - "github.com/davecgh/go-spew/spew" log "github.com/sirupsen/logrus" "github.com/zeebo/bencode" ) @@ -17,7 +16,7 @@ func TestPing(t *testing.T) { conn := newTestUDPConn("127.0.0.1:21217") - dht := New(&Config{Address: ":21216", NodeID: dhtNodeID.Hex()}) + dht := New(&Config{Address: "127.0.0.1:21216", NodeID: dhtNodeID.Hex()}) dht.conn = conn dht.listen() go dht.runHandler() @@ -45,8 +44,7 @@ func TestPing(t *testing.T) { var response map[string]interface{} err := bencode.DecodeBytes(resp.data, &response) if err != nil { - t.Error(err) - return + t.Fatal(err) } if len(response) != 4 { @@ -109,7 +107,7 @@ func TestStore(t *testing.T) { conn := newTestUDPConn("127.0.0.1:21217") - dht := New(&Config{Address: ":21216", NodeID: dhtNodeID.Hex()}) + dht := New(&Config{Address: "127.0.0.1:21216", NodeID: dhtNodeID.Hex()}) dht.conn = conn dht.listen() go dht.runHandler() @@ -149,8 +147,7 @@ func TestStore(t *testing.T) { data, err := bencode.EncodeBytes(storeRequest) if err != nil { - t.Error(err) - return + t.Fatal(err) } conn.toRead <- testUDPPacket{addr: conn.addr, data: data} @@ -159,13 +156,11 @@ func TestStore(t *testing.T) { var response map[string]interface{} select { case <-timer.C: - t.Error("timeout") - return + t.Fatal("timeout") case resp := <-conn.writes: err := bencode.DecodeBytes(resp.data, &response) if err != nil { - t.Error(err) - return + t.Fatal(err) } } @@ -191,7 +186,7 @@ func TestStore(t *testing.T) { if len(items) != 1 { t.Error("list created in store, but nothing in list") } - if !items[0].Equals(testNodeID) { + if !items[0].id.Equals(testNodeID) { t.Error("wrong value stored") } } @@ -202,7 +197,7 @@ func TestFindNode(t *testing.T) { conn := newTestUDPConn("127.0.0.1:21217") - dht := New(&Config{Address: ":21216", NodeID: dhtNodeID.Hex()}) + dht := New(&Config{Address: "127.0.0.1:21216", NodeID: dhtNodeID.Hex()}) dht.conn = conn dht.listen() go dht.runHandler() @@ -210,7 +205,7 @@ func TestFindNode(t *testing.T) { nodesToInsert := 3 var nodes []Node for i := 0; i < nodesToInsert; i++ { - n := Node{id: newRandomBitmap(), ip: "127.0.0.1", port: 10000 + i} + n := Node{id: newRandomBitmap(), ip: net.ParseIP("127.0.0.1"), port: 10000 + i} nodes = append(nodes, n) dht.routingTable.Update(&n) } @@ -227,8 +222,7 @@ func TestFindNode(t *testing.T) { data, err := bencode.EncodeBytes(request) if err != nil { - t.Error(err) - return + t.Fatal(err) } conn.toRead <- testUDPPacket{addr: conn.addr, data: data} @@ -237,13 +231,11 @@ func TestFindNode(t *testing.T) { var response map[string]interface{} select { case <-timer.C: - t.Error("timeout") - return + t.Fatal("timeout") case resp := <-conn.writes: err := bencode.DecodeBytes(resp.data, &response) if err != nil { - t.Error(err) - return + t.Fatal(err) } } @@ -251,45 +243,176 @@ func TestFindNode(t *testing.T) { _, ok := response[headerPayloadField] if !ok { - t.Error("missing payload field") - } else { - contacts, ok := response[headerPayloadField].([]interface{}) - if !ok { - t.Error("payload is not a list") - } else { - verifyContacts(t, contacts, nodes) - } + t.Fatal("missing payload field") } + + payload, ok := response[headerPayloadField].(map[string]interface{}) + if !ok { + t.Fatal("payload is not a dictionary") + } + + contactsList, ok := payload["contacts"] + if !ok { + t.Fatal("payload is missing 'contacts' key") + } + + contacts, ok := contactsList.([]interface{}) + if !ok { + t.Fatal("'contacts' is not a list") + } + + verifyContacts(t, contacts, nodes) } -func TestFindValue(t *testing.T) { +func TestFindValueExisting(t *testing.T) { dhtNodeID := newRandomBitmap() + testNodeID := newRandomBitmap() conn := newTestUDPConn("127.0.0.1:21217") - dht := New(&Config{Address: ":21216", NodeID: dhtNodeID.Hex()}) + dht := New(&Config{Address: "127.0.0.1:21216", NodeID: dhtNodeID.Hex()}) dht.conn = conn dht.listen() go dht.runHandler() - data, _ := hex.DecodeString("64313a30693065313a3132303a7de8e57d34e316abbb5a8a8da50dcd1ad4c80e0f313a3234383a7ce1b831dec8689e44f80f547d2dea171f6a625e1a4ff6c6165e645f953103dabeb068a622203f859c6c64658fd3aa3b313a33393a66696e6456616c7565313a346c34383aa47624b8e7ee1e54df0c45e2eb858feb0b705bd2a78d8b739be31ba188f4bd6f56b371c51fecc5280d5fd26ba4168e966565") + nodesToInsert := 3 + var nodes []Node + for i := 0; i < nodesToInsert; i++ { + n := Node{id: newRandomBitmap(), ip: net.ParseIP("127.0.0.1"), port: 10000 + i} + nodes = append(nodes, n) + dht.routingTable.Update(&n) + } + + //data, _ := hex.DecodeString("64313a30693065313a3132303a7de8e57d34e316abbb5a8a8da50dcd1ad4c80e0f313a3234383a7ce1b831dec8689e44f80f547d2dea171f6a625e1a4ff6c6165e645f953103dabeb068a622203f859c6c64658fd3aa3b313a33393a66696e6456616c7565313a346c34383aa47624b8e7ee1e54df0c45e2eb858feb0b705bd2a78d8b739be31ba188f4bd6f56b371c51fecc5280d5fd26ba4168e966565") + + messageID := newRandomBitmap().RawString() + valueToFind := newRandomBitmap().RawString() + + nodeToFind := Node{id: newRandomBitmap(), ip: net.ParseIP("1.2.3.4"), port: 1286} + dht.store.Insert(valueToFind, nodeToFind) + + request := Request{ + ID: messageID, + NodeID: testNodeID.RawString(), + Method: findValueMethod, + Args: []string{valueToFind}, + } + + data, err := bencode.EncodeBytes(request) + if err != nil { + t.Fatal(err) + } conn.toRead <- testUDPPacket{addr: conn.addr, data: data} timer := time.NewTimer(3 * time.Second) + var response map[string]interface{} select { case <-timer.C: - t.Error("timeout") + t.Fatal("timeout") case resp := <-conn.writes: - var response map[string]interface{} err := bencode.DecodeBytes(resp.data, &response) if err != nil { - t.Error(err) - return + t.Fatal(err) } - - spew.Dump(response) } + + verifyResponse(t, response, messageID, dhtNodeID.RawString()) + + _, ok := response[headerPayloadField] + if !ok { + t.Fatal("missing payload field") + } + + payload, ok := response[headerPayloadField].(map[string]interface{}) + if !ok { + t.Fatal("payload is not a dictionary") + } + + compactContacts, ok := payload[valueToFind] + if !ok { + t.Fatal("payload is missing key for search value") + } + + contacts, ok := compactContacts.([]interface{}) + if !ok { + t.Fatal("search results are not a list") + } + + verifyCompactContacts(t, contacts, []Node{nodeToFind}) +} + +func TestFindValueFallbackToFindNode(t *testing.T) { + dhtNodeID := newRandomBitmap() + testNodeID := newRandomBitmap() + + conn := newTestUDPConn("127.0.0.1:21217") + + dht := New(&Config{Address: "127.0.0.1:21216", NodeID: dhtNodeID.Hex()}) + dht.conn = conn + dht.listen() + go dht.runHandler() + + nodesToInsert := 3 + var nodes []Node + for i := 0; i < nodesToInsert; i++ { + n := Node{id: newRandomBitmap(), ip: net.ParseIP("127.0.0.1"), port: 10000 + i} + nodes = append(nodes, n) + dht.routingTable.Update(&n) + } + + messageID := newRandomBitmap().RawString() + valueToFind := newRandomBitmap().RawString() + + request := Request{ + ID: messageID, + NodeID: testNodeID.RawString(), + Method: findValueMethod, + Args: []string{valueToFind}, + } + + data, err := bencode.EncodeBytes(request) + if err != nil { + t.Fatal(err) + } + + conn.toRead <- testUDPPacket{addr: conn.addr, data: data} + timer := time.NewTimer(3 * time.Second) + + var response map[string]interface{} + select { + case <-timer.C: + t.Fatal("timeout") + case resp := <-conn.writes: + err := bencode.DecodeBytes(resp.data, &response) + if err != nil { + t.Fatal(err) + } + } + + verifyResponse(t, response, messageID, dhtNodeID.RawString()) + + _, ok := response[headerPayloadField] + if !ok { + t.Fatal("missing payload field") + } + + payload, ok := response[headerPayloadField].(map[string]interface{}) + if !ok { + t.Fatal("payload is not a dictionary") + } + + contactsList, ok := payload["contacts"] + if !ok { + t.Fatal("payload is missing 'contacts' key") + } + + contacts, ok := contactsList.([]interface{}) + if !ok { + t.Fatal("'contacts' is not a list") + } + + verifyContacts(t, contacts, nodes) } func verifyResponse(t *testing.T, resp map[string]interface{}, messageID, dhtNodeID string) { @@ -382,8 +505,8 @@ func verifyContacts(t *testing.T, contacts []interface{}, nodes []Node) { ip, ok := contact[1].(string) if !ok { t.Error("contact IP is not a string") - } else if ip != currNode.ip { - t.Errorf("contact IP mismatch. got %s; expected %s", ip, currNode.ip) + } else if !currNode.ip.Equal(net.ParseIP(ip)) { + t.Errorf("contact IP mismatch. got %s; expected %s", ip, currNode.ip.String()) } port, ok := contact[2].(int64) @@ -394,3 +517,55 @@ func verifyContacts(t *testing.T, contacts []interface{}, nodes []Node) { } } } + +func verifyCompactContacts(t *testing.T, contacts []interface{}, nodes []Node) { + if len(contacts) != len(nodes) { + t.Errorf("got %d contacts; expected %d", len(contacts), len(nodes)) + return + } + + foundNodes := make(map[string]bool) + + for _, c := range contacts { + compact, ok := c.(string) + if !ok { + t.Error("contact is not a string") + return + } + + contact := Node{} + err := contact.UnmarshalCompact([]byte(compact)) + if err != nil { + t.Error(err) + return + } + + var currNode Node + currNodeFound := false + + if _, ok := foundNodes[contact.id.Hex()]; ok { + t.Errorf("contact %s appears multiple times", contact.id.Hex()) + continue + } + for _, n := range nodes { + if n.id.Equals(contact.id) { + currNode = n + currNodeFound = true + foundNodes[contact.id.Hex()] = true + break + } + } + if !currNodeFound { + t.Errorf("unexpected contact %s", contact.id.Hex()) + continue + } + + if !currNode.ip.Equal(contact.ip) { + t.Errorf("contact IP mismatch. got %s; expected %s", contact.ip.String(), currNode.ip.String()) + } + + if contact.port != currNode.port { + t.Errorf("contact port mismatch. got %d; expected %d", contact.port, currNode.port) + } + } +} diff --git a/dht/message.go b/dht/message.go index d3e9455..e7cb440 100644 --- a/dht/message.go +++ b/dht/message.go @@ -179,8 +179,18 @@ func (r Response) MarshalBencode() ([]byte, error) { } if r.Data != "" { data[headerPayloadField] = r.Data + } else if r.FindValueKey != "" { + var contacts [][]byte + for _, n := range r.FindNodeData { + compact, err := n.MarshalCompact() + if err != nil { + return nil, err + } + contacts = append(contacts, compact) + } + data[headerPayloadField] = map[string][][]byte{r.FindValueKey: contacts} } else { - data[headerPayloadField] = r.FindNodeData + data[headerPayloadField] = map[string][]Node{"contacts": r.FindNodeData} } return bencode.EncodeBytes(data) diff --git a/dht/routing_table.go b/dht/routing_table.go index 0bc2841..32b86ab 100644 --- a/dht/routing_table.go +++ b/dht/routing_table.go @@ -31,7 +31,7 @@ func (n Node) MarshalCompact() ([]byte, error) { buf.WriteByte(byte(n.port)) buf.Write(n.id[:]) - if buf.Len() != nodeIDLength+6 { + if buf.Len() != compactNodeInfoLength { return nil, errors.Err("i dont know how this happened") } @@ -39,14 +39,11 @@ func (n Node) MarshalCompact() ([]byte, error) { } func (n *Node) UnmarshalCompact(b []byte) error { - if len(b) != 6 { - return errors.Err("invalid compact ip/port") + if len(b) != compactNodeInfoLength { + return errors.Err("invalid compact length") } - copy(n.ip, b[0:4]) + n.ip = net.IPv4(b[0], b[1], b[2], b[3]) n.port = int(uint16(b[5]) | uint16(b[4])<<8) - if n.port < 0 || n.port > 65535 { - return errors.Err("invalid port") - } n.id = newBitmapFromBytes(b[6:]) return nil } diff --git a/dht/routing_table_test.go b/dht/routing_table_test.go index 3bb5a4d..1d33016 100644 --- a/dht/routing_table_test.go +++ b/dht/routing_table_test.go @@ -2,9 +2,8 @@ package dht import ( "net" + "reflect" "testing" - - "github.com/davecgh/go-spew/spew" ) func TestRoutingTable(t *testing.T) { @@ -40,8 +39,8 @@ func TestRoutingTable(t *testing.T) { func TestCompactEncoding(t *testing.T) { n := Node{ id: newBitmapFromHex("1c8aff71b99462464d9eeac639595ab99664be3482cb91a29d87467515c7d9158fe72aa1f1582dab07d8f8b5db277f41"), - ip: net.ParseIP("255.1.0.155"), - port: 66666, + ip: net.ParseIP("1.2.3.4"), + port: int(55<<8 + 66), } var compact []byte @@ -50,9 +49,11 @@ func TestCompactEncoding(t *testing.T) { t.Fatal(err) } - if len(compact) != nodeIDLength+6 { - t.Fatalf("got length of %d; expected %d", len(compact), nodeIDLength+6) + if len(compact) != compactNodeInfoLength { + t.Fatalf("got length of %d; expected %d", len(compact), compactNodeInfoLength) } - spew.Dump(compact) + if !reflect.DeepEqual(compact, append([]byte{1, 2, 3, 4, 55, 66}, n.id[:]...)) { + t.Errorf("compact bytes not encoded correctly") + } } diff --git a/dht/store.go b/dht/store.go index 895c61a..2178e60 100644 --- a/dht/store.go +++ b/dht/store.go @@ -3,7 +3,7 @@ package dht import "sync" type peer struct { - nodeID bitmap + node Node } type peerStore struct { @@ -17,10 +17,10 @@ func newPeerStore() *peerStore { } } -func (s *peerStore) Insert(key string, nodeId bitmap) { +func (s *peerStore) Insert(key string, node Node) { s.lock.Lock() defer s.lock.Unlock() - newPeer := peer{nodeID: nodeId} + newPeer := peer{node: node} _, ok := s.data[key] if !ok { s.data[key] = []peer{newPeer} @@ -29,13 +29,13 @@ func (s *peerStore) Insert(key string, nodeId bitmap) { } } -func (s *peerStore) Get(key string) []bitmap { +func (s *peerStore) Get(key string) []Node { s.lock.RLock() defer s.lock.RUnlock() - var nodes []bitmap + var nodes []Node if peers, ok := s.data[key]; ok { for _, p := range peers { - nodes = append(nodes, p.nodeID) + nodes = append(nodes, p.node) } } return nodes diff --git a/main.go b/main.go index 9c39b66..40ca997 100644 --- a/main.go +++ b/main.go @@ -20,31 +20,8 @@ func checkErr(err error) { func main() { rand.Seed(time.Now().UnixNano()) log.SetLevel(log.DebugLevel) - cmd.GlobalConfig = loadConfig("config.json") - cmd.Execute() - - // - //var err error - //client := reflector.Client{} - // - //log.Println("Connecting to " + reflectorAddress) - //err = client.Connect(reflectorAddress) - //checkErr(err) - // - //log.Println("Connected") - // - //defer func() { - // log.Println("Closing connection") - // client.Close() - //}() - // - //blob := make([]byte, 2*1024*1024) - //_, err = rand.Read(blob) - //checkErr(err) - //err = client.SendBlob(blob) - //checkErr(err) } func loadConfig(path string) cmd.Config {