diff --git a/dht/bits/bitmap.go b/dht/bits/bitmap.go index 0605fb6..f0d8a26 100644 --- a/dht/bits/bitmap.go +++ b/dht/bits/bitmap.go @@ -344,6 +344,11 @@ func MaxP() Bitmap { return FromHexP(strings.Repeat("f", NumBytes*2)) } +// Min returns a bitmap with all bits set to 0 +func MinP() Bitmap { + return FromHexP(strings.Repeat("0", NumBytes*2)) +} + // Rand generates a cryptographically random bitmap with the confines of the parameters specified. func Rand() Bitmap { var id Bitmap diff --git a/dht/routing_table.go b/dht/routing_table.go index 9a109ea..dd0f69c 100644 --- a/dht/routing_table.go +++ b/dht/routing_table.go @@ -149,27 +149,29 @@ func (b *bucket) NeedsRefresh(refreshInterval time.Duration) bool { type routingTable struct { id bits.Bitmap buckets []bucket + lock *sync.RWMutex } func newRoutingTable(id bits.Bitmap) *routingTable { var rt routingTable rt.id = id + rt.lock = &sync.RWMutex{} rt.reset() return &rt } func (rt *routingTable) reset() { - start := big.NewInt(0) - end := big.NewInt(1) - end.Lsh(end, bits.NumBits) - end.Sub(end, big.NewInt(1)) + rt.Lock() + defer rt.Unlock() + newBucketLock := &sync.RWMutex{} + newBucketLock.Lock() rt.buckets = []bucket{} rt.buckets = append(rt.buckets, bucket{ peers: make([]peer, 0, bucketSize), - lock: &sync.RWMutex{}, + lock: newBucketLock, bucketRange: bits.Range{ - Start: bits.FromBigP(start), - End: bits.FromBigP(end), + Start: bits.MinP(), + End: bits.MaxP(), }, }) } @@ -207,9 +209,33 @@ func (rt *routingTable) Fail(c Contact) { rt.bucketFor(c.ID).FailContact(c.ID) } +func (rt *routingTable) getClosestToUs(limit int) []Contact { + contacts := []Contact{} + toSort := []sortedContact{} + rt.lock.RLock() + defer rt.lock.RUnlock() + for _, bucket := range rt.buckets { + toSort = []sortedContact{} + toSort = appendContacts(toSort, bucket, rt.id) + sort.Sort(byXorDistance(toSort)) + for _, sorted := range toSort { + contacts = append(contacts, sorted.contact) + if len(contacts) >= limit { + break + } + } + } + return contacts +} + // GetClosest returns the closest `limit` contacts from the routing table // It marks each bucket it accesses as having been accessed func (rt *routingTable) GetClosest(target bits.Bitmap, limit int) []Contact { + if target == rt.id { + return rt.getClosestToUs(limit) + } + rt.lock.RLock() + defer rt.lock.RUnlock() toSort := []sortedContact{} for _, b := range rt.buckets { toSort = appendContacts(toSort, b, target) @@ -235,15 +261,26 @@ func appendContacts(contacts []sortedContact, b bucket, target bits.Bitmap) []so // Count returns the number of contacts in the routing table func (rt *routingTable) Count() int { count := 0 + rt.lock.RLock() + defer rt.lock.RUnlock() for _, bucket := range rt.buckets { count += bucket.Len() } return count } +// Len returns the number of buckets in the routing table +func (rt *routingTable) Len() int { + rt.lock.RLock() + defer rt.lock.RUnlock() + return len(rt.buckets) +} + // BucketRanges returns a slice of ranges, where the `start` of each range is the smallest id that can // go in that bucket, and the `end` is the largest id func (rt *routingTable) BucketRanges() []bits.Range { + rt.lock.RLock() + defer rt.lock.RUnlock() ranges := make([]bits.Range, len(rt.buckets)) for i, b := range rt.buckets { ranges[i] = b.bucketRange @@ -252,6 +289,8 @@ func (rt *routingTable) BucketRanges() []bits.Range { } func (rt *routingTable) bucketNumFor(target bits.Bitmap) int { + rt.lock.RLock() + defer rt.lock.RUnlock() if rt.id.Equals(target) { panic("routing table does not have a bucket for its own id") } @@ -265,13 +304,16 @@ func (rt *routingTable) bucketNumFor(target bits.Bitmap) int { } func (rt *routingTable) bucketFor(target bits.Bitmap) *bucket { - return &rt.buckets[rt.bucketNumFor(target)] + bucketIndex := rt.bucketNumFor(target) + rt.lock.RLock() + defer rt.lock.RUnlock() + return &rt.buckets[bucketIndex] } func (rt *routingTable) shouldSplit(target bits.Bitmap) bool { - bucketIndex := rt.bucketNumFor(target) - if len(rt.buckets[bucketIndex].peers) >= bucketSize { - if bucketIndex == 0 { // this is the bucket covering our node id + b := rt.bucketFor(target) + if b.Len() >= bucketSize { + if b.bucketRange.Start.Equals(bits.MinP()) { // this is the bucket covering our node id return true } kClosest := rt.GetClosest(rt.id, bucketSize) @@ -285,20 +327,35 @@ func (rt *routingTable) shouldSplit(target bits.Bitmap) bool { func (rt *routingTable) insertContact(c Contact) { bucketIndex := rt.bucketNumFor(c.ID) - peersInBucket := int(len(rt.buckets[bucketIndex].peers)) + peersInBucket :=rt.buckets[bucketIndex].Len() if peersInBucket < bucketSize { rt.buckets[rt.bucketNumFor(c.ID)].UpdateContact(c, true) } else if peersInBucket >= bucketSize && rt.shouldSplit(c.ID) { rt.splitBucket(bucketIndex) rt.insertContact(c) + rt.popEmptyBuckets() + } +} + +func (rt * routingTable) Lock() { + rt.lock.Lock() + for _, buk := range rt.buckets { + buk.lock.Lock() + } +} + +func (rt * routingTable) Unlock() { + rt.lock.Unlock() + for _, buk := range rt.buckets { + buk.lock.Unlock() } - rt.popEmptyBuckets() } func (rt *routingTable) splitBucket(bucketIndex int) { + rt.Lock() + defer rt.Unlock() b := rt.buckets[bucketIndex] - min := b.bucketRange.Start.Big() max := b.bucketRange.End.Big() midpoint := &big.Int{} @@ -310,7 +367,6 @@ func (rt *routingTable) splitBucket(bucketIndex int) { midpointPlusOne.Add(midpoint, big.NewInt(1)) first_half := rt.buckets[:bucketIndex+1] - second_half := []bucket{} for i := bucketIndex + 1; i < len(rt.buckets); i++ { second_half = append(second_half, rt.buckets[i]) @@ -321,20 +377,22 @@ func (rt *routingTable) splitBucket(bucketIndex int) { b.peers = []peer{} rt.buckets = []bucket{} - for _, i := range first_half { - rt.buckets = append(rt.buckets, i) + for _, buk := range first_half { + rt.buckets = append(rt.buckets, buk) } + newBucketLock := &sync.RWMutex{} + newBucketLock.Lock() // will be unlocked by the deferred rt.Unlock() newBucket := bucket{ peers: make([]peer, 0, bucketSize), - lock: &sync.RWMutex{}, + lock: newBucketLock, bucketRange: bits.Range{ Start: bits.FromBigP(midpointPlusOne), End: bits.FromBigP(max), }, } rt.buckets = append(rt.buckets, newBucket) - for _, i := range second_half { - rt.buckets = append(rt.buckets, i) + for _, buk := range second_half { + rt.buckets = append(rt.buckets, buk) } // re-size the bucket to be split rt.buckets[bucketIndex].bucketRange.Start = bits.FromBigP(min) @@ -393,6 +451,9 @@ func (rt *routingTable) popNextEmptyBucket() bool { } func (rt *routingTable) popEmptyBuckets() { + rt.Lock() + defer rt.Unlock() + if len(rt.buckets) > 1 { popBuckets := rt.popNextEmptyBucket() for popBuckets == true { diff --git a/dht/routing_table_test.go b/dht/routing_table_test.go index 69cc8f4..5ab3493 100644 --- a/dht/routing_table_test.go +++ b/dht/routing_table_test.go @@ -11,30 +11,6 @@ import ( "github.com/sebdah/goldie" ) -func TestRoutingTable_bucketFor(t *testing.T) { - rt := newRoutingTable(bits.FromHexP("000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000")) - var tests = []struct { - id bits.Bitmap - expected int - }{ - {bits.FromHexP("000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000001"), 0}, - {bits.FromHexP("000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000002"), 1}, - {bits.FromHexP("000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000003"), 1}, - {bits.FromHexP("000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000004"), 2}, - {bits.FromHexP("000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000005"), 2}, - {bits.FromHexP("00000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000f"), 3}, - {bits.FromHexP("000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000010"), 4}, - {bits.FromHexP("F00000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000"), 383}, - {bits.FromHexP("F0000000000000000000000000000000F0000000000000000000000000F0000000000000000000000000000000000000"), 383}, - } - - for _, tt := range tests { - bucket := rt.bucketNumFor(tt.id) - if bucket != tt.expected { - t.Errorf("bucketFor(%s, %s) => %d, want %d", tt.id.Hex(), rt.id.Hex(), bucket, tt.expected) - } - } -} func checkBucketCount(rt *routingTable, t *testing.T, correctSize, correctCount, testCaseIndex int) { if len(rt.buckets) != correctSize { @@ -121,6 +97,26 @@ func TestSplitBuckets(t *testing.T) { checkBucketCount(rt, t, testCase.expectedBucketCount, testCase.expectedTotalContacts, i) checkRangeContinuity(rt, t) } + + var testRanges = []struct { + id bits.Bitmap + expected int + }{ + {bits.FromHexP("000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000001"), 0}, + {bits.FromHexP("000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000005"), 0}, + {bits.FromHexP("200000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000010"), 1}, + {bits.FromHexP("380000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000"), 2}, + {bits.FromHexP("F00000000000000000000000000000000000000000000000000F00000000000000000000000000000000000000000000"), 3}, + {bits.FromHexP("F0000000000000000000000000000000F0000000000000000000000000F0000000000000000000000000000000000000"), 3}, + } + + for _, tt := range testRanges { + bucket := rt.bucketNumFor(tt.id) + if bucket != tt.expected { + t.Errorf("bucketFor(%s, %s) => %d, want %d", tt.id.Hex(), rt.id.Hex(), bucket, tt.expected) + } + } + rt.printBucketInfo() } @@ -208,32 +204,6 @@ func TestRoutingTable_MoveToBack(t *testing.T) { } } -func TestRoutingTable_InitialBucketRange(t *testing.T) { - id := bits.FromHexP("1c8aff71b99462464d9eeac639595ab99664be3482cb91a29d87467515c7d9158fe72aa1f1582dab07d8f8b5db277f41") - rt := newRoutingTable(id) - ranges := rt.BucketRanges() - bucketRange := ranges[0] - if len(ranges) != 1 { - t.Error("there should only be one bucket") - } - if !ranges[0].Start.Equals(bits.FromHexP("000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000")) { - t.Error("bucket does not cover the lower keyspace") - } - if !ranges[0].End.Equals(bits.FromHexP("ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff")) { - t.Error("bucket does not cover the upper keyspace") - } - found := 0 - for i := 0; i < 1000; i++ { - randID := bits.Rand() - if bucketRange.Start.Cmp(randID) <= 0 && bucketRange.End.Cmp(randID) >= 0 { - found += 1 - } - } - if found != 1000 { - t.Errorf("%d did not appear in any bucket", found) - } - log.Println(rt.Count()) -} func TestRoutingTable_Save(t *testing.T) { id := bits.FromHexP("1c8aff71b99462464d9eeac639595ab99664be3482cb91a29d87467515c7d9158fe72aa1f1582dab07d8f8b5db277f41")