diff --git a/addrmgr/addrmanager.go b/addrmgr/addrmanager.go index fc693068..a8a8fb33 100644 --- a/addrmgr/addrmanager.go +++ b/addrmgr/addrmanager.go @@ -46,6 +46,7 @@ type AddrManager struct { nNew int lamtx sync.Mutex localAddresses map[string]*localAddress + version int } type serializedKnownAddress struct { @@ -55,6 +56,8 @@ type serializedKnownAddress struct { TimeStamp int64 LastAttempt int64 LastSuccess int64 + Services wire.ServiceFlag + SrcServices wire.ServiceFlag // no refcount or tried, that is available from context. } @@ -155,7 +158,7 @@ const ( getAddrPercent = 23 // serialisationVersion is the current version of the on-disk format. - serialisationVersion = 1 + serialisationVersion = 2 ) // updateAddress is a helper function to either update an address already known @@ -362,7 +365,7 @@ func (a *AddrManager) savePeers() { // First we make a serialisable datastructure so we can encode it to // json. sam := new(serializedAddrManager) - sam.Version = serialisationVersion + sam.Version = a.version copy(sam.Key[:], a.key[:]) sam.Addresses = make([]*serializedKnownAddress, len(a.addrIndex)) @@ -375,6 +378,10 @@ func (a *AddrManager) savePeers() { ska.Attempts = v.attempts ska.LastAttempt = v.lastattempt.Unix() ska.LastSuccess = v.lastsuccess.Unix() + if a.version > 1 { + ska.Services = v.na.Services + ska.SrcServices = v.srcAddr.Services + } // Tried and refs are implicit in the rest of the structure // and will be worked out from context on unserialisation. sam.Addresses[i] = ska @@ -451,24 +458,43 @@ func (a *AddrManager) deserializePeers(filePath string) error { return fmt.Errorf("error reading %s: %v", filePath, err) } - if sam.Version != serialisationVersion { + // Since decoding JSON is backwards compatible (i.e., only decodes + // fields it understands), we'll only return an error upon seeing a + // version past our latest supported version. + if sam.Version > serialisationVersion { return fmt.Errorf("unknown version %v in serialized "+ "addrmanager", sam.Version) } + copy(a.key[:], sam.Key[:]) for _, v := range sam.Addresses { ka := new(KnownAddress) - ka.na, err = a.DeserializeNetAddress(v.Addr) + + // The first version of the serialized address manager was not + // aware of the service bits associated with this address, so + // we'll assign a default of SFNodeNetwork to it. + if sam.Version == 1 { + v.Services = wire.SFNodeNetwork + } + ka.na, err = a.DeserializeNetAddress(v.Addr, v.Services) if err != nil { return fmt.Errorf("failed to deserialize netaddress "+ "%s: %v", v.Addr, err) } - ka.srcAddr, err = a.DeserializeNetAddress(v.Src) + + // The first version of the serialized address manager was not + // aware of the service bits associated with the source address, + // so we'll assign a default of SFNodeNetwork to it. + if sam.Version == 1 { + v.SrcServices = wire.SFNodeNetwork + } + ka.srcAddr, err = a.DeserializeNetAddress(v.Src, v.SrcServices) if err != nil { return fmt.Errorf("failed to deserialize netaddress "+ "%s: %v", v.Src, err) } + ka.attempts = v.Attempts ka.lastattempt = time.Unix(v.LastAttempt, 0) ka.lastsuccess = time.Unix(v.LastSuccess, 0) @@ -520,8 +546,10 @@ func (a *AddrManager) deserializePeers(filePath string) error { return nil } -// DeserializeNetAddress converts a given address string to a *wire.NetAddress -func (a *AddrManager) DeserializeNetAddress(addr string) (*wire.NetAddress, error) { +// DeserializeNetAddress converts a given address string to a *wire.NetAddress. +func (a *AddrManager) DeserializeNetAddress(addr string, + services wire.ServiceFlag) (*wire.NetAddress, error) { + host, portStr, err := net.SplitHostPort(addr) if err != nil { return nil, err @@ -531,7 +559,7 @@ func (a *AddrManager) DeserializeNetAddress(addr string) (*wire.NetAddress, erro return nil, err } - return a.HostToNetAddress(host, uint16(port), wire.SFNodeNetwork) + return a.HostToNetAddress(host, uint16(port), services) } // Start begins the core address handler which manages a pool of known @@ -635,21 +663,9 @@ func (a *AddrManager) NeedMoreAddresses() bool { // AddressCache returns the current address cache. It must be treated as // read-only (but since it is a copy now, this is not as dangerous). func (a *AddrManager) AddressCache() []*wire.NetAddress { - a.mtx.Lock() - defer a.mtx.Unlock() + allAddr := a.getAddresses() - addrIndexLen := len(a.addrIndex) - if addrIndexLen == 0 { - return nil - } - - allAddr := make([]*wire.NetAddress, 0, addrIndexLen) - // Iteration order is undefined here, but we randomise it anyway. - for _, v := range a.addrIndex { - allAddr = append(allAddr, v.na) - } - - numAddresses := addrIndexLen * getAddrPercent / 100 + numAddresses := len(allAddr) * getAddrPercent / 100 if numAddresses > getAddrMax { numAddresses = getAddrMax } @@ -658,7 +674,7 @@ func (a *AddrManager) AddressCache() []*wire.NetAddress { // `numAddresses' since we are throwing the rest. for i := 0; i < numAddresses; i++ { // pick a number between current index and the end - j := rand.Intn(addrIndexLen-i) + i + j := rand.Intn(len(allAddr)-i) + i allAddr[i], allAddr[j] = allAddr[j], allAddr[i] } @@ -666,6 +682,25 @@ func (a *AddrManager) AddressCache() []*wire.NetAddress { return allAddr[0:numAddresses] } +// getAddresses returns all of the addresses currently found within the +// manager's address cache. +func (a *AddrManager) getAddresses() []*wire.NetAddress { + a.mtx.Lock() + defer a.mtx.Unlock() + + addrIndexLen := len(a.addrIndex) + if addrIndexLen == 0 { + return nil + } + + addrs := make([]*wire.NetAddress, 0, addrIndexLen) + for _, v := range a.addrIndex { + addrs = append(addrs, v.na) + } + + return addrs +} + // reset resets the address manager by reinitialising the random source // and allocating fresh empty bucket storage. func (a *AddrManager) reset() { @@ -1109,6 +1144,7 @@ func New(dataDir string, lookupFunc func(string) ([]net.IP, error)) *AddrManager rand: rand.New(rand.NewSource(time.Now().UnixNano())), quit: make(chan struct{}), localAddresses: make(map[string]*localAddress), + version: serialisationVersion, } am.reset() return &am diff --git a/addrmgr/addrmanager_internal_test.go b/addrmgr/addrmanager_internal_test.go new file mode 100644 index 00000000..1c19dceb --- /dev/null +++ b/addrmgr/addrmanager_internal_test.go @@ -0,0 +1,194 @@ +package addrmgr + +import ( + "io/ioutil" + "math/rand" + "net" + "os" + "testing" + + "github.com/btcsuite/btcd/wire" +) + +// randAddr generates a *wire.NetAddress backed by a random IPv4/IPv6 address. +func randAddr(t *testing.T) *wire.NetAddress { + t.Helper() + + ipv4 := rand.Intn(2) == 0 + var ip net.IP + if ipv4 { + var b [4]byte + if _, err := rand.Read(b[:]); err != nil { + t.Fatal(err) + } + ip = b[:] + } else { + var b [16]byte + if _, err := rand.Read(b[:]); err != nil { + t.Fatal(err) + } + ip = b[:] + } + + return &wire.NetAddress{ + Services: wire.ServiceFlag(rand.Uint64()), + IP: ip, + Port: uint16(rand.Uint32()), + } +} + +// assertAddr ensures that the two addresses match. The timestamp is not +// checked as it does not affect uniquely identifying a specific address. +func assertAddr(t *testing.T, got, expected *wire.NetAddress) { + if got.Services != expected.Services { + t.Fatalf("expected address services %v, got %v", + expected.Services, got.Services) + } + if !got.IP.Equal(expected.IP) { + t.Fatalf("expected address IP %v, got %v", expected.IP, got.IP) + } + if got.Port != expected.Port { + t.Fatalf("expected address port %d, got %d", expected.Port, + got.Port) + } +} + +// assertAddrs ensures that the manager's address cache matches the given +// expected addresses. +func assertAddrs(t *testing.T, addrMgr *AddrManager, + expectedAddrs map[string]*wire.NetAddress) { + + t.Helper() + + addrs := addrMgr.getAddresses() + + if len(addrs) != len(expectedAddrs) { + t.Fatalf("expected to find %d addresses, found %d", + len(expectedAddrs), len(addrs)) + } + + for _, addr := range addrs { + addrStr := NetAddressKey(addr) + expectedAddr, ok := expectedAddrs[addrStr] + if !ok { + t.Fatalf("expected to find address %v", addrStr) + } + + assertAddr(t, addr, expectedAddr) + } +} + +// TestAddrManagerSerialization ensures that we can properly serialize and +// deserialize the manager's current address cache. +func TestAddrManagerSerialization(t *testing.T) { + t.Parallel() + + // We'll start by creating our address manager backed by a temporary + // directory. + tempDir, err := ioutil.TempDir("", "addrmgr") + if err != nil { + t.Fatalf("unable to create temp dir: %v", err) + } + defer os.RemoveAll(tempDir) + + addrMgr := New(tempDir, nil) + + // We'll be adding 5 random addresses to the manager. + const numAddrs = 5 + + expectedAddrs := make(map[string]*wire.NetAddress, numAddrs) + for i := 0; i < numAddrs; i++ { + addr := randAddr(t) + expectedAddrs[NetAddressKey(addr)] = addr + addrMgr.AddAddress(addr, randAddr(t)) + } + + // Now that the addresses have been added, we should be able to retrieve + // them. + assertAddrs(t, addrMgr, expectedAddrs) + + // Then, we'll persist these addresses to disk and restart the address + // manager. + addrMgr.savePeers() + addrMgr = New(tempDir, nil) + + // Finally, we'll read all of the addresses from disk and ensure they + // match as expected. + addrMgr.loadPeers() + assertAddrs(t, addrMgr, expectedAddrs) +} + +// TestAddrManagerV1ToV2 ensures that we can properly upgrade the serialized +// version of the address manager from v1 to v2. +func TestAddrManagerV1ToV2(t *testing.T) { + t.Parallel() + + // We'll start by creating our address manager backed by a temporary + // directory. + tempDir, err := ioutil.TempDir("", "addrmgr") + if err != nil { + t.Fatalf("unable to create temp dir: %v", err) + } + defer os.RemoveAll(tempDir) + + addrMgr := New(tempDir, nil) + + // As we're interested in testing the upgrade path from v1 to v2, we'll + // override the manager's current version. + addrMgr.version = 1 + + // We'll be adding 5 random addresses to the manager. Since this is v1, + // each addresses' services will not be stored. + const numAddrs = 5 + + expectedAddrs := make(map[string]*wire.NetAddress, numAddrs) + for i := 0; i < numAddrs; i++ { + addr := randAddr(t) + expectedAddrs[NetAddressKey(addr)] = addr + addrMgr.AddAddress(addr, randAddr(t)) + } + + // Then, we'll persist these addresses to disk and restart the address + // manager - overriding its version back to v1. + addrMgr.savePeers() + addrMgr = New(tempDir, nil) + addrMgr.version = 1 + + // When we read all of the addresses back from disk, we should expect to + // find all of them, but their services will be set to a default of + // SFNodeNetwork since they were not previously stored. After ensuring + // that this default is set, we'll override each addresses' services + // with the original value from when they were created. + addrMgr.loadPeers() + addrs := addrMgr.getAddresses() + if len(addrs) != len(expectedAddrs) { + t.Fatalf("expected to find %d adddresses, found %d", + len(expectedAddrs), len(addrs)) + } + for _, addr := range addrs { + addrStr := NetAddressKey(addr) + expectedAddr, ok := expectedAddrs[addrStr] + if !ok { + t.Fatalf("expected to find address %v", addrStr) + } + + if addr.Services != wire.SFNodeNetwork { + t.Fatalf("expected address services to be %v, got %v", + wire.SFNodeNetwork, addr.Services) + } + + addrMgr.SetServices(addr, expectedAddr.Services) + } + + // We'll also bump up the manager's version to v2, which should signal + // that it should include the address services when persisting its + // state. + addrMgr.version = 2 + addrMgr.savePeers() + + // Finally, we'll recreate the manager and ensure that the services were + // persisted correctly. + addrMgr = New(tempDir, nil) + addrMgr.loadPeers() + assertAddrs(t, addrMgr, expectedAddrs) +} diff --git a/addrmgr/addrmanager_test.go b/addrmgr/addrmanager_test.go index fcfe845f..676913e2 100644 --- a/addrmgr/addrmanager_test.go +++ b/addrmgr/addrmanager_test.go @@ -262,7 +262,7 @@ func TestNeedMoreAddresses(t *testing.T) { var err error for i := 0; i < addrsToAdd; i++ { s := fmt.Sprintf("%d.%d.173.147:8333", i/128+60, i%128+60) - addrs[i], err = n.DeserializeNetAddress(s) + addrs[i], err = n.DeserializeNetAddress(s, wire.SFNodeNetwork) if err != nil { t.Errorf("Failed to turn %s into an address: %v", s, err) } @@ -290,7 +290,7 @@ func TestGood(t *testing.T) { var err error for i := 0; i < addrsToAdd; i++ { s := fmt.Sprintf("%d.173.147.%d:8333", i/64+60, i%64+60) - addrs[i], err = n.DeserializeNetAddress(s) + addrs[i], err = n.DeserializeNetAddress(s, wire.SFNodeNetwork) if err != nil { t.Errorf("Failed to turn %s into an address: %v", s, err) } diff --git a/addrmgr/knownaddress.go b/addrmgr/knownaddress.go index 4a8da83e..15469f37 100644 --- a/addrmgr/knownaddress.go +++ b/addrmgr/knownaddress.go @@ -33,6 +33,11 @@ func (ka *KnownAddress) LastAttempt() time.Time { return ka.lastattempt } +// Services returns the services supported by the peer with the known address. +func (ka *KnownAddress) Services() wire.ServiceFlag { + return ka.na.Services +} + // chance returns the selection probability for a known address. The priority // depends upon how recently the address has been seen, how recently it was last // attempted and how often attempts to connect to it have failed.