diff --git a/chain/log.go b/chain/log.go index d336eac..5eadfcd 100644 --- a/chain/log.go +++ b/chain/log.go @@ -4,7 +4,10 @@ package chain -import "github.com/btcsuite/btclog" +import ( + "github.com/btcsuite/btclog" + "github.com/lightninglabs/neutrino/query" +) // log is a logger that is initialized with no output filters. This // means the package will not perform any logging by default until the caller @@ -27,4 +30,5 @@ func DisableLog() { // using btclog. func UseLogger(logger btclog.Logger) { log = logger + query.UseLogger(logger) } diff --git a/chain/pruned_block_dispatcher.go b/chain/pruned_block_dispatcher.go new file mode 100644 index 0000000..6edb93a --- /dev/null +++ b/chain/pruned_block_dispatcher.go @@ -0,0 +1,625 @@ +package chain + +import ( + "encoding/binary" + "encoding/hex" + "errors" + "fmt" + "math/rand" + "net" + "sync" + "time" + + "github.com/btcsuite/btcd/blockchain" + "github.com/btcsuite/btcd/btcjson" + "github.com/btcsuite/btcd/chaincfg" + "github.com/btcsuite/btcd/chaincfg/chainhash" + "github.com/btcsuite/btcd/peer" + "github.com/btcsuite/btcd/wire" + "github.com/btcsuite/btcutil" + "github.com/lightninglabs/neutrino/query" + "github.com/lightningnetwork/lnd/ticker" +) + +const ( + // defaultRefreshPeersInterval represents the default polling interval + // at which we attempt to refresh the set of known peers. + defaultRefreshPeersInterval = 30 * time.Second + + // defaultPeerReadyTimeout is the default amount of time we'll wait for + // a query peer to be ready to receive incoming block requests. Peers + // cannot respond to requests until the version exchange is completed + // upon connection establishment. + defaultPeerReadyTimeout = 15 * time.Second + + // requiredServices are the requires services we require any candidate + // peers to signal such that we can retrieve pruned blocks from them. + requiredServices = wire.SFNodeNetwork | wire.SFNodeWitness + + // prunedNodeService is the service bit signaled by pruned nodes on the + // network. + prunedNodeService wire.ServiceFlag = 1 << 11 +) + +// queryPeer represents a Bitcoin network peer that we'll query for blocks. +// The ready channel serves as a signal for us to know when we can be sending +// queries to the peer. Any messages received from the peer are sent through the +// msgsRecvd channel. +type queryPeer struct { + *peer.Peer + ready chan struct{} + msgsRecvd chan wire.Message + quit chan struct{} +} + +// signalUponDisconnect closes the peer's quit chan to signal it has +// disconnected. +func (p *queryPeer) signalUponDisconnect(f func()) { + go func() { + p.WaitForDisconnect() + close(p.quit) + f() + }() +} + +// SubscribeRecvMsg adds a OnRead subscription to the peer. All bitcoin messages +// received from this peer will be sent on the returned channel. A closure is +// also returned, that should be called to cancel the subscription. +// +// NOTE: This method exists to satisfy the query.Peer interface. +func (p *queryPeer) SubscribeRecvMsg() (<-chan wire.Message, func()) { + return p.msgsRecvd, func() {} +} + +// OnDisconnect returns a channel that will be closed once the peer disconnects. +// +// NOTE: This method exists to satisfy the query.Peer interface. +func (p *queryPeer) OnDisconnect() <-chan struct{} { + return p.quit +} + +// PrunedBlockDispatcherConfig encompasses all of the dependencies required by +// the PrunedBlockDispatcher to carry out its duties. +type PrunedBlockDispatcherConfig struct { + // ChainParams represents the parameters of the current active chain. + ChainParams *chaincfg.Params + + // NumTargetPeer represents the target number of peers we should + // maintain connections with. This exists to prevent establishing + // connections to all of the bitcoind's peers, which would be + // unnecessary and ineffecient. + NumTargetPeers int + + // Dial establishes connections to Bitcoin peers. This must support + // dialing peers running over Tor if the backend also supports it. + Dial func(string) (net.Conn, error) + + // GetPeers retrieves the active set of peers known to the backend node. + GetPeers func() ([]btcjson.GetPeerInfoResult, error) + + // PeerReadyTimeout is the amount of time we'll wait for a query peer to + // be ready to receive incoming block requests. Peers cannot respond to + // requests until the version exchange is completed upon connection + // establishment. + PeerReadyTimeout time.Duration + + // RefreshPeersTicker is the polling ticker that signals us when we + // should attempt to refresh the set of known peers. + RefreshPeersTicker ticker.Ticker + + // AllowSelfPeerConns is only used to allow the tests to bypass the peer + // self connection detecting and disconnect logic since they + // intentionally do so for testing purposes. + AllowSelfPeerConns bool + + // MaxRequestInvs dictates how many invs we should fit in a single + // getdata request to a peer. This only exists to facilitate the testing + // of a request spanning multiple getdata messages. + MaxRequestInvs int +} + +// PrunedBlockDispatcher enables a chain client to request blocks that the +// server has already pruned. This is done by connecting to the server's full +// node peers and querying them directly. Ideally, this is a capability +// supported by the server, though this is not yet possible with bitcoind. +type PrunedBlockDispatcher struct { + cfg PrunedBlockDispatcherConfig + + // workManager handles satisfying all of our incoming pruned block + // requests. + workManager *query.WorkManager + + // blocksQueried represents the set of pruned blocks we've been + // requested to query. Each block maps to a list of clients waiting to + // be notified once the block is received. + // + // NOTE: The blockMtx lock must always be held when accessing this + // field. + blocksQueried map[chainhash.Hash][]chan *wire.MsgBlock + blockMtx sync.Mutex + + // currentPeers represents the set of peers we're currently connected + // to. Each peer found here will have a worker spawned within the + // workManager to handle our queries. + // + // NOTE: The peerMtx lock must always be held when accessing this + // field. + currentPeers map[string]*peer.Peer + + // bannedPeers represents the set of peers who have sent us an invalid + // reply corresponding to a query. Peers within this set should not be + // dialed. + // + // NOTE: The peerMtx lock must always be held when accessing this + // field. + bannedPeers map[string]struct{} + peerMtx sync.Mutex + + // peersConnected is the channel through which we'll send new peers + // we've established connections to. + peersConnected chan query.Peer + + // timeSource provides a mechanism to add several time samples which are + // used to determine a median time which is then used as an offset to + // the local clock when validating blocks received from peers. + timeSource blockchain.MedianTimeSource + + quit chan struct{} + wg sync.WaitGroup +} + +// NewPrunedBlockDispatcher initializes a new PrunedBlockDispatcher instance +// backed by the given config. +func NewPrunedBlockDispatcher(cfg *PrunedBlockDispatcherConfig) ( + *PrunedBlockDispatcher, error) { + + if cfg.NumTargetPeers < 1 { + return nil, errors.New("config option NumTargetPeer must be >= 1") + } + if cfg.MaxRequestInvs > wire.MaxInvPerMsg { + return nil, fmt.Errorf("config option MaxRequestInvs must be "+ + "<= %v", wire.MaxInvPerMsg) + } + + peersConnected := make(chan query.Peer) + return &PrunedBlockDispatcher{ + cfg: *cfg, + workManager: query.New(&query.Config{ + ConnectedPeers: func() (<-chan query.Peer, func(), error) { + return peersConnected, func() {}, nil + }, + NewWorker: query.NewWorker, + Ranking: query.NewPeerRanking(), + }), + blocksQueried: make(map[chainhash.Hash][]chan *wire.MsgBlock), + currentPeers: make(map[string]*peer.Peer), + bannedPeers: make(map[string]struct{}), + peersConnected: peersConnected, + timeSource: blockchain.NewMedianTime(), + quit: make(chan struct{}), + }, nil +} + +// Start allows the PrunedBlockDispatcher to begin handling incoming block +// requests. +func (d *PrunedBlockDispatcher) Start() error { + log.Tracef("Starting pruned block dispatcher") + + if err := d.workManager.Start(); err != nil { + return err + } + + d.wg.Add(1) + go d.pollPeers() + + return nil +} + +// Stop stops the PrunedBlockDispatcher from accepting any more incoming block +// requests. +func (d *PrunedBlockDispatcher) Stop() { + log.Tracef("Stopping pruned block dispatcher") + + close(d.quit) + d.wg.Wait() + + _ = d.workManager.Stop() +} + +// pollPeers continuously polls the backend node for new peers to establish +// connections to. +func (d *PrunedBlockDispatcher) pollPeers() { + defer d.wg.Done() + + if err := d.connectToPeers(); err != nil { + log.Warnf("Unable to establish peer connections: %v", err) + } + + d.cfg.RefreshPeersTicker.Resume() + defer d.cfg.RefreshPeersTicker.Stop() + + for { + select { + case <-d.cfg.RefreshPeersTicker.Ticks(): + // Quickly determine if we need any more peer + // connections. If we don't, we'll wait for our next + // tick. + d.peerMtx.Lock() + peersNeeded := d.cfg.NumTargetPeers - len(d.currentPeers) + d.peerMtx.Unlock() + if peersNeeded <= 0 { + continue + } + + // If we do, attempt to establish connections until + // we've reached our target number. + if err := d.connectToPeers(); err != nil { + log.Warnf("Unable to establish peer "+ + "connections: %v", err) + continue + } + + case <-d.quit: + return + } + } +} + +// connectToPeers attempts to establish new peer connections until the target +// number is reached. Once a connection is successfully established, the peer is +// sent through the peersConnected channel to notify the internal workManager. +func (d *PrunedBlockDispatcher) connectToPeers() error { + // Refresh the list of peers our backend is currently connected to, and + // filter out any that do not meet our requirements. + peers, err := d.cfg.GetPeers() + if err != nil { + return err + } + peers, err = filterPeers(peers) + if err != nil { + return err + } + rand.Shuffle(len(peers), func(i, j int) { + peers[i], peers[j] = peers[j], peers[i] + }) + + // For each unbanned peer we don't already have a connection to, try to + // establish one, and if successful, notify the peer. + for _, peer := range peers { + d.peerMtx.Lock() + _, isBanned := d.bannedPeers[peer.Addr] + _, isConnected := d.currentPeers[peer.Addr] + d.peerMtx.Unlock() + if isBanned || isConnected { + continue + } + + queryPeer, err := d.newQueryPeer(peer) + if err != nil { + return fmt.Errorf("unable to configure query peer %v: "+ + "%v", peer.Addr, err) + } + if err := d.connectToPeer(queryPeer); err != nil { + log.Debugf("Failed connecting to peer %v: %v", + peer.Addr, err) + continue + } + + select { + case d.peersConnected <- queryPeer: + case <-d.quit: + return errors.New("shutting down") + } + + // If the new peer helped us reach our target number, we're done + // and can exit. + d.peerMtx.Lock() + d.currentPeers[queryPeer.Addr()] = queryPeer.Peer + numPeers := len(d.currentPeers) + d.peerMtx.Unlock() + if numPeers == d.cfg.NumTargetPeers { + break + } + } + + return nil +} + +// filterPeers filters out any peers which cannot handle arbitrary witness block +// requests, i.e., any peer which is not considered a segwit-enabled +// "full-node". +func filterPeers(peers []btcjson.GetPeerInfoResult) ( + []btcjson.GetPeerInfoResult, error) { + + var eligible []btcjson.GetPeerInfoResult + for _, peer := range peers { + rawServices, err := hex.DecodeString(peer.Services) + if err != nil { + return nil, err + } + services := wire.ServiceFlag(binary.BigEndian.Uint64(rawServices)) + + // Skip nodes that cannot serve full block witness data. + if services&requiredServices != requiredServices { + continue + } + // Skip pruned nodes. + if services&prunedNodeService == prunedNodeService { + continue + } + + eligible = append(eligible, peer) + } + + return eligible, nil +} + +// newQueryPeer creates a new peer instance configured to relay any received +// messages to the internal workManager. +func (d *PrunedBlockDispatcher) newQueryPeer( + peerInfo btcjson.GetPeerInfoResult) (*queryPeer, error) { + + ready := make(chan struct{}) + msgsRecvd := make(chan wire.Message) + + cfg := &peer.Config{ + ChainParams: d.cfg.ChainParams, + // We're not interested in transactions, so disable their relay. + DisableRelayTx: true, + Listeners: peer.MessageListeners{ + // Add the remote peer time as a sample for creating an + // offset against the local clock to keep the network + // time in sync. + OnVersion: func(p *peer.Peer, msg *wire.MsgVersion) *wire.MsgReject { + d.timeSource.AddTimeSample(p.Addr(), msg.Timestamp) + return nil + }, + // Register a callback to signal us when we can start + // querying the peer for blocks. + OnVerAck: func(*peer.Peer, *wire.MsgVerAck) { + close(ready) + }, + // Register a callback to signal us whenever the peer + // has sent us a block message. + OnRead: func(p *peer.Peer, _ int, msg wire.Message, err error) { + if err != nil { + return + } + + var block *wire.MsgBlock + switch msg := msg.(type) { + case *wire.MsgBlock: + block = msg + case *wire.MsgVersion, *wire.MsgVerAck: + return + default: + log.Debugf("Received unexpected message "+ + "%T from peer %v", msg, p.Addr()) + return + } + + select { + case msgsRecvd <- block: + case <-d.quit: + } + }, + }, + AllowSelfConns: true, + } + p, err := peer.NewOutboundPeer(cfg, peerInfo.Addr) + if err != nil { + return nil, err + } + + return &queryPeer{ + Peer: p, + ready: ready, + msgsRecvd: msgsRecvd, + quit: make(chan struct{}), + }, nil +} + +// connectToPeer attempts to establish a connection to the given peer and waits +// up to PeerReadyTimeout for the version exchange to complete so that we can +// begin sending it our queries. +func (d *PrunedBlockDispatcher) connectToPeer(peer *queryPeer) error { + conn, err := d.cfg.Dial(peer.Addr()) + if err != nil { + return err + } + peer.AssociateConnection(conn) + + select { + case <-peer.ready: + case <-time.After(d.cfg.PeerReadyTimeout): + peer.Disconnect() + return errors.New("timed out waiting for protocol negotiation") + case <-d.quit: + return errors.New("shutting down") + } + + // Remove the peer once it has disconnected. + peer.signalUponDisconnect(func() { + d.peerMtx.Lock() + delete(d.currentPeers, peer.Addr()) + d.peerMtx.Unlock() + }) + + return nil +} + +// banPeer bans a peer by disconnecting them and ensuring we don't reconnect. +func (d *PrunedBlockDispatcher) banPeer(peer string) { + d.peerMtx.Lock() + defer d.peerMtx.Unlock() + + d.bannedPeers[peer] = struct{}{} + if p, ok := d.currentPeers[peer]; ok { + p.Disconnect() + } +} + +// Query submits a request to query the information of the given blocks. +func (d *PrunedBlockDispatcher) Query(blocks []*chainhash.Hash, + opts ...query.QueryOption) (<-chan *wire.MsgBlock, <-chan error) { + + reqs, blockChan, err := d.newRequest(blocks) + if err != nil { + errChan := make(chan error, 1) + errChan <- err + return nil, errChan + } + + var errChan chan error + if len(reqs) > 0 { + errChan = d.workManager.Query(reqs, opts...) + } + return blockChan, errChan +} + +// newRequest construct a new query request for the given blocks to submit to +// the internal workManager. A channel is also returned through which the +// requested blocks are sent through. +func (d *PrunedBlockDispatcher) newRequest(blocks []*chainhash.Hash) ( + []*query.Request, <-chan *wire.MsgBlock, error) { + + // Make sure the channel is buffered enough to handle all blocks. + blockChan := make(chan *wire.MsgBlock, len(blocks)) + + d.blockMtx.Lock() + defer d.blockMtx.Unlock() + + // Each GetData message can only include up to MaxRequestInvs invs, + // and each block consumes a single inv. + var ( + reqs []*query.Request + getData *wire.MsgGetData + ) + for i, block := range blocks { + if getData == nil { + getData = wire.NewMsgGetData() + } + + if _, ok := d.blocksQueried[*block]; !ok { + log.Debugf("Queuing new block %v for request", *block) + inv := wire.NewInvVect(wire.InvTypeBlock, block) + if err := getData.AddInvVect(inv); err != nil { + return nil, nil, err + } + } else { + log.Debugf("Received new request for pending query of "+ + "block %v", *block) + } + + d.blocksQueried[*block] = append( + d.blocksQueried[*block], blockChan, + ) + + // If we have any invs to request, or we've reached the maximum + // allowed, queue the getdata message as is, and proceed to the + // next if any. + if (len(getData.InvList) > 0 && i == len(blocks)-1) || + len(getData.InvList) == d.cfg.MaxRequestInvs { + + reqs = append(reqs, &query.Request{ + Req: getData, + HandleResp: d.handleResp, + }) + getData = nil + } + } + + return reqs, blockChan, nil +} + +// handleResp is a response handler that will be called for every message +// received from the peer that the request was made to. It should validate the +// response against the request made, and return a Progress indicating whether +// the request was answered by this particular response. +// +// NOTE: Since the worker's job queue will be stalled while this method is +// running, it should not be doing any expensive operations. It should validate +// the response and immediately return the progress. The response should be +// handed off to another goroutine for processing. +func (d *PrunedBlockDispatcher) handleResp(req, resp wire.Message, + peer string) query.Progress { + + // We only expect MsgBlock as replies. + block, ok := resp.(*wire.MsgBlock) + if !ok { + return query.Progress{ + Progressed: false, + Finished: false, + } + } + + // We only serve MsgGetData requests. + getData, ok := req.(*wire.MsgGetData) + if !ok { + return query.Progress{ + Progressed: false, + Finished: false, + } + } + + // Check that we've actually queried for this block and validate it. + blockHash := block.BlockHash() + d.blockMtx.Lock() + blockChans, ok := d.blocksQueried[blockHash] + if !ok { + d.blockMtx.Unlock() + return query.Progress{ + Progressed: false, + Finished: false, + } + } + + err := blockchain.CheckBlockSanity( + btcutil.NewBlock(block), d.cfg.ChainParams.PowLimit, + d.timeSource, + ) + if err != nil { + d.blockMtx.Unlock() + + log.Warnf("Received invalid block %v from peer %v: %v", + blockHash, peer, err) + d.banPeer(peer) + + return query.Progress{ + Progressed: false, + Finished: false, + } + } + + // Once validated, we can safely remove it. + delete(d.blocksQueried, blockHash) + + // Check whether we have any other pending blocks we've yet to receive. + // If we do, we'll mark the response as progressing our query, but not + // completing it yet. + progress := query.Progress{Progressed: true, Finished: true} + for _, inv := range getData.InvList { + if _, ok := d.blocksQueried[inv.Hash]; ok { + progress.Finished = false + break + } + } + d.blockMtx.Unlock() + + // Launch a goroutine to notify all clients of the block as we don't + // want to potentially block our workManager. + d.wg.Add(1) + go func() { + defer d.wg.Done() + + for _, blockChan := range blockChans { + select { + case blockChan <- block: + case <-d.quit: + return + } + } + }() + + return progress +} diff --git a/chain/pruned_block_dispatcher_test.go b/chain/pruned_block_dispatcher_test.go new file mode 100644 index 0000000..06cd7c8 --- /dev/null +++ b/chain/pruned_block_dispatcher_test.go @@ -0,0 +1,590 @@ +package chain + +import ( + "encoding/binary" + "encoding/hex" + "fmt" + "net" + "os" + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/btcsuite/btcd/btcjson" + "github.com/btcsuite/btcd/chaincfg" + "github.com/btcsuite/btcd/chaincfg/chainhash" + "github.com/btcsuite/btcd/peer" + "github.com/btcsuite/btcd/wire" + "github.com/btcsuite/btclog" + "github.com/lightningnetwork/lnd/ticker" + "github.com/stretchr/testify/require" +) + +func init() { + b := btclog.NewBackend(os.Stdout) + l := b.Logger("") + l.SetLevel(btclog.LevelTrace) + UseLogger(l) +} + +var ( + addrCounter int32 // Increased atomically. + + chainParams = chaincfg.RegressionNetParams +) + +func nextAddr() string { + port := atomic.AddInt32(&addrCounter, 1) + return fmt.Sprintf("10.0.0.1:%d", port) +} + +// prunedBlockDispatcherHarness is a harness used to facilitate the testing of the +// PrunedBlockDispatcher. +type prunedBlockDispatcherHarness struct { + t *testing.T + + dispatcher *PrunedBlockDispatcher + + hashes []*chainhash.Hash + blocks map[chainhash.Hash]*wire.MsgBlock + + peerMtx sync.Mutex + peers map[string]*peer.Peer + localConns map[string]net.Conn // Connections to peers. + remoteConns map[string]net.Conn // Connections from peers. + + dialedPeer chan struct{} + queriedPeer chan struct{} + blocksQueried map[chainhash.Hash]int + + shouldReply uint32 // 0 == true, 1 == false, 2 == invalid reply +} + +// newNetworkBlockTestHarness initializes a new PrunedBlockDispatcher test harness +// backed by a custom chain and peers. +func newNetworkBlockTestHarness(t *testing.T, numBlocks, + numPeers, numWorkers uint32) *prunedBlockDispatcherHarness { + + h := &prunedBlockDispatcherHarness{ + t: t, + peers: make(map[string]*peer.Peer, numPeers), + localConns: make(map[string]net.Conn, numPeers), + remoteConns: make(map[string]net.Conn, numPeers), + dialedPeer: make(chan struct{}), + queriedPeer: make(chan struct{}), + blocksQueried: make(map[chainhash.Hash]int), + } + + h.hashes, h.blocks = genBlockChain(numBlocks) + for i := uint32(0); i < numPeers; i++ { + h.addPeer() + } + + dial := func(addr string) (net.Conn, error) { + go func() { + h.dialedPeer <- struct{}{} + }() + + h.peerMtx.Lock() + defer h.peerMtx.Unlock() + + localConn, ok := h.localConns[addr] + if !ok { + return nil, fmt.Errorf("local conn %v not found", addr) + } + remoteConn, ok := h.remoteConns[addr] + if !ok { + return nil, fmt.Errorf("remote conn %v not found", addr) + } + + h.peers[addr].AssociateConnection(remoteConn) + return localConn, nil + } + + var err error + h.dispatcher, err = NewPrunedBlockDispatcher(&PrunedBlockDispatcherConfig{ + ChainParams: &chainParams, + NumTargetPeers: int(numWorkers), + Dial: dial, + GetPeers: func() ([]btcjson.GetPeerInfoResult, error) { + h.peerMtx.Lock() + defer h.peerMtx.Unlock() + + res := make([]btcjson.GetPeerInfoResult, 0, len(h.peers)) + for addr, peer := range h.peers { + var rawServices [8]byte + binary.BigEndian.PutUint64( + rawServices[:], uint64(peer.Services()), + ) + + res = append(res, btcjson.GetPeerInfoResult{ + Addr: addr, + Services: hex.EncodeToString(rawServices[:]), + }) + } + + return res, nil + }, + PeerReadyTimeout: time.Hour, + RefreshPeersTicker: ticker.NewForce(time.Hour), + AllowSelfPeerConns: true, + MaxRequestInvs: wire.MaxInvPerMsg, + }) + require.NoError(t, err) + + return h +} + +// start starts the PrunedBlockDispatcher and asserts that connections are made +// to all available peers. +func (h *prunedBlockDispatcherHarness) start() { + h.t.Helper() + + err := h.dispatcher.Start() + require.NoError(h.t, err) + + h.peerMtx.Lock() + numPeers := len(h.peers) + h.peerMtx.Unlock() + + for i := 0; i < numPeers; i++ { + h.assertPeerDialed() + } +} + +// stop stops the PrunedBlockDispatcher and asserts that all internal fields of +// the harness have been properly consumed. +func (h *prunedBlockDispatcherHarness) stop() { + h.dispatcher.Stop() + + select { + case <-h.dialedPeer: + h.t.Fatal("did not consume all dialedPeer signals") + default: + } + + select { + case <-h.queriedPeer: + h.t.Fatal("did not consume all queriedPeer signals") + default: + } + + require.Empty(h.t, h.blocksQueried) +} + +// addPeer adds a new random peer available for use by the +// PrunedBlockDispatcher. +func (h *prunedBlockDispatcherHarness) addPeer() string { + addr := nextAddr() + + h.peerMtx.Lock() + defer h.peerMtx.Unlock() + + h.resetPeer(addr) + return addr +} + +// resetPeer resets the internal peer connection state allowing the +// PrunedBlockDispatcher to establish a mock connection to it. +func (h *prunedBlockDispatcherHarness) resetPeer(addr string) { + h.peers[addr] = h.newPeer() + + // Establish a mock connection between us and each peer. + inConn, outConn := pipe( + &conn{localAddr: addr, remoteAddr: "10.0.0.1:8333"}, + &conn{localAddr: "10.0.0.1:8333", remoteAddr: addr}, + ) + h.localConns[addr] = outConn + h.remoteConns[addr] = inConn +} + +// newPeer returns a new properly configured peer.Peer instance that will be +// used by the PrunedBlockDispatcher. +func (h *prunedBlockDispatcherHarness) newPeer() *peer.Peer { + return peer.NewInboundPeer(&peer.Config{ + ChainParams: &chainParams, + DisableRelayTx: true, + Listeners: peer.MessageListeners{ + OnGetData: func(p *peer.Peer, msg *wire.MsgGetData) { + go func() { + h.queriedPeer <- struct{}{} + }() + + for _, inv := range msg.InvList { + // Invs should always be for blocks. + require.Equal(h.t, wire.InvTypeBlock, inv.Type) + + // Invs should always be for known blocks. + block, ok := h.blocks[inv.Hash] + require.True(h.t, ok) + + switch atomic.LoadUint32(&h.shouldReply) { + // Don't reply if requested. + case 1: + continue + // Make the block invalid and send it. + case 2: + block = produceInvalidBlock(block) + } + + go p.QueueMessage(block, nil) + } + }, + }, + Services: wire.SFNodeNetwork | wire.SFNodeWitness, + AllowSelfConns: true, + }) +} + +// query requests the given blocks from the PrunedBlockDispatcher. +func (h *prunedBlockDispatcherHarness) query(blocks []*chainhash.Hash) ( + <-chan *wire.MsgBlock, <-chan error) { + + h.t.Helper() + + blockChan, errChan := h.dispatcher.Query(blocks) + select { + case err := <-errChan: + require.NoError(h.t, err) + default: + } + + for _, block := range blocks { + h.blocksQueried[*block]++ + } + + return blockChan, errChan +} + +// disablePeerReplies prevents the query peer from replying. +func (h *prunedBlockDispatcherHarness) disablePeerReplies() { + atomic.StoreUint32(&h.shouldReply, 1) +} + +// enablePeerReplies allows the query peer to reply. +func (h *prunedBlockDispatcherHarness) enablePeerReplies() { + atomic.StoreUint32(&h.shouldReply, 0) +} + +// enableInvalidPeerReplies +func (h *prunedBlockDispatcherHarness) enableInvalidPeerReplies() { + atomic.StoreUint32(&h.shouldReply, 2) +} + +// refreshPeers forces the RefreshPeersTicker to fire. +func (h *prunedBlockDispatcherHarness) refreshPeers() { + h.t.Helper() + + h.dispatcher.cfg.RefreshPeersTicker.(*ticker.Force).Force <- time.Now() +} + +// disconnectPeer simulates a peer disconnecting from the PrunedBlockDispatcher. +func (h *prunedBlockDispatcherHarness) disconnectPeer(addr string) { + h.t.Helper() + + h.peerMtx.Lock() + defer h.peerMtx.Unlock() + + require.Contains(h.t, h.peers, addr) + + // Obtain the current number of peers before disconnecting such that we + // can block until the peer has been fully disconnected. + h.dispatcher.peerMtx.Lock() + numPeers := len(h.dispatcher.currentPeers) + h.dispatcher.peerMtx.Unlock() + + h.peers[addr].Disconnect() + + require.Eventually(h.t, func() bool { + h.dispatcher.peerMtx.Lock() + defer h.dispatcher.peerMtx.Unlock() + return len(h.dispatcher.currentPeers) == numPeers-1 + }, time.Second, 200*time.Millisecond) + + // Reset the peer connection state to allow connections to them again. + h.resetPeer(addr) +} + +// assertPeerDialed asserts that a connection was made to the given peer. +func (h *prunedBlockDispatcherHarness) assertPeerDialed() { + h.t.Helper() + + select { + case <-h.dialedPeer: + case <-time.After(5 * time.Second): + h.t.Fatalf("expected peer to be dialed") + } +} + +// assertPeerQueried asserts that query was sent to the given peer. +func (h *prunedBlockDispatcherHarness) assertPeerQueried() { + h.t.Helper() + + select { + case <-h.queriedPeer: + case <-time.After(5 * time.Second): + h.t.Fatalf("expected a peer to be queried") + } +} + +// assertPeerReplied asserts that the query peer replies with a block the +// PrunedBlockDispatcher queried for. +func (h *prunedBlockDispatcherHarness) assertPeerReplied( + blockChan <-chan *wire.MsgBlock, errChan <-chan error, + expectCompletionSignal bool) { + + h.t.Helper() + + select { + case block := <-blockChan: + blockHash := block.BlockHash() + _, ok := h.blocksQueried[blockHash] + require.True(h.t, ok) + + expBlock, ok := h.blocks[blockHash] + require.True(h.t, ok) + require.Equal(h.t, expBlock, block) + + // Decrement how many clients queried the same block. Once we + // have none left, remove it from the map. + h.blocksQueried[blockHash]-- + if h.blocksQueried[blockHash] == 0 { + delete(h.blocksQueried, blockHash) + } + + case <-time.After(5 * time.Second): + select { + case err := <-errChan: + h.t.Fatalf("received unexpected error send: %v", err) + default: + } + h.t.Fatal("expected reply from peer") + } + + // If we should expect a nil error to be sent by the internal + // workManager to signal completion of the request, wait for it now. + if expectCompletionSignal { + select { + case err := <-errChan: + require.NoError(h.t, err) + case <-time.After(5 * time.Second): + h.t.Fatal("expected nil err to signal completion") + } + } +} + +// assertNoPeerDialed asserts that the PrunedBlockDispatcher hasn't established +// a new peer connection. +func (h *prunedBlockDispatcherHarness) assertNoPeerDialed() { + h.t.Helper() + + select { + case peer := <-h.dialedPeer: + h.t.Fatalf("unexpected connection established with peer %v", peer) + case <-time.After(2 * time.Second): + } +} + +// assertNoReply asserts that the peer hasn't replied to a query. +func (h *prunedBlockDispatcherHarness) assertNoReply( + blockChan <-chan *wire.MsgBlock, errChan <-chan error) { + + h.t.Helper() + + select { + case block := <-blockChan: + h.t.Fatalf("received unexpected block %v", block.BlockHash()) + case err := <-errChan: + h.t.Fatalf("received unexpected error send: %v", err) + case <-time.After(2 * time.Second): + } +} + +// TestPrunedBlockDispatcherQuerySameBlock tests that client requests for the +// same block result in only fetching the block once while pending. +func TestPrunedBlockDispatcherQuerySameBlock(t *testing.T) { + t.Parallel() + + const numBlocks = 1 + const numPeers = 5 + const numRequests = numBlocks * numPeers + + h := newNetworkBlockTestHarness(t, numBlocks, numPeers, numPeers) + h.start() + defer h.stop() + + // Queue all the block requests one by one. + blockChans := make([]<-chan *wire.MsgBlock, 0, numRequests) + errChans := make([]<-chan error, 0, numRequests) + for i := 0; i < numRequests; i++ { + blockChan, errChan := h.query(h.hashes) + blockChans = append(blockChans, blockChan) + errChans = append(errChans, errChan) + } + + // We should only see one query. + h.assertPeerQueried() + for i := 0; i < numRequests; i++ { + h.assertPeerReplied(blockChans[i], errChans[i], i == 0) + } +} + +// TestPrunedBlockDispatcherMultipleGetData tests that a client requesting blocks +// that span across multiple queries works as intended. +func TestPrunedBlockDispatcherMultipleGetData(t *testing.T) { + t.Parallel() + + const maxRequestInvs = 5 + const numBlocks = (maxRequestInvs * 5) + 1 + + h := newNetworkBlockTestHarness(t, numBlocks, 1, 1) + h.dispatcher.cfg.MaxRequestInvs = maxRequestInvs + h.start() + defer h.stop() + + // Request all blocks. + blockChan, errChan := h.query(h.hashes) + + // Since we have more blocks than can fit in a single GetData message, + // we should expect multiple queries. For each query, we should expect + // wire.MaxInvPerMsg replies until we've received all of them. + blocksRecvd := 0 + numMsgs := (numBlocks / maxRequestInvs) + if numBlocks%wire.MaxInvPerMsg > 0 { + numMsgs++ + } + for i := 0; i < numMsgs; i++ { + h.assertPeerQueried() + for j := 0; j < maxRequestInvs; j++ { + expectCompletionSignal := blocksRecvd == numBlocks-1 + h.assertPeerReplied( + blockChan, errChan, expectCompletionSignal, + ) + + blocksRecvd++ + if blocksRecvd == numBlocks { + break + } + } + } +} + +// TestPrunedBlockDispatcherMultipleQueryPeers tests that client requests are +// distributed across multiple query peers. +func TestPrunedBlockDispatcherMultipleQueryPeers(t *testing.T) { + t.Parallel() + + const numBlocks = 10 + const numPeers = numBlocks / 2 + + h := newNetworkBlockTestHarness(t, numBlocks, numPeers, numPeers) + h.start() + defer h.stop() + + // Queue all the block requests one by one. + blockChans := make([]<-chan *wire.MsgBlock, 0, numBlocks) + errChans := make([]<-chan error, 0, numBlocks) + for i := 0; i < numBlocks; i++ { + blockChan, errChan := h.query(h.hashes[i : i+1]) + blockChans = append(blockChans, blockChan) + errChans = append(errChans, errChan) + } + + // We should see one query per block. + for i := 0; i < numBlocks; i++ { + h.assertPeerQueried() + h.assertPeerReplied(blockChans[i], errChans[i], i == numBlocks-1) + } +} + +// TestPrunedBlockDispatcherPeerPoller ensures that the peer poller can detect +// when more connections are required to satisfy a request. +func TestPrunedBlockDispatcherPeerPoller(t *testing.T) { + t.Parallel() + + // Initialize our harness as usual, but don't create any peers yet. + h := newNetworkBlockTestHarness(t, 1, 0, 2) + h.start() + defer h.stop() + + // We shouldn't see any peers dialed since we don't have any. + h.assertNoPeerDialed() + + // We'll then query for a block. + blockChan, errChan := h.query(h.hashes) + + // Refresh our peers. This would dial some peers, but we don't have any + // yet. + h.refreshPeers() + h.assertNoPeerDialed() + + // Add a new peer and force a refresh. We should see the peer be dialed. + // We'll disable replies for now, as we'll want to test the disconnect + // case. + h.disablePeerReplies() + peer := h.addPeer() + h.refreshPeers() + h.assertPeerDialed() + h.assertPeerQueried() + + // Disconnect our peer and re-enable replies. + h.disconnectPeer(peer) + h.enablePeerReplies() + h.assertNoReply(blockChan, errChan) + + // Force a refresh once again. Since the peer has disconnected, a new + // connection should be made and the peer should be queried again. + h.refreshPeers() + h.assertPeerDialed() + h.assertPeerQueried() + + // Refresh our peers again. We can afford to have one more query peer, + // but there isn't another one available. We also shouldn't dial the one + // we're currently connected to again. + h.refreshPeers() + h.assertNoPeerDialed() + + // Now that we know we've connected to the peer, we should be able to + // receive their response. + h.assertPeerReplied(blockChan, errChan, true) +} + +// TestPrunedBlockDispatcherInvalidBlock ensures that validation is performed on +// blocks received from peers, and that any peers which have sent an invalid +// block are banned and not connected to. +func TestPrunedBlockDispatcherInvalidBlock(t *testing.T) { + t.Parallel() + + h := newNetworkBlockTestHarness(t, 1, 1, 1) + h.start() + defer h.stop() + + // We'll start the test by signaling our peer to send an invalid block. + h.enableInvalidPeerReplies() + + // We'll then query for a block. We shouldn't see a response as the + // block should have failed validation. + blockChan, errChan := h.query(h.hashes) + h.assertPeerQueried() + h.assertNoReply(blockChan, errChan) + + // Since the peer sent us an invalid block, they should have been + // disconnected and banned. Refreshing our peers shouldn't result in a + // new connection attempt because we don't have any other peers + // available. + h.refreshPeers() + h.assertNoPeerDialed() + + // Signal to our peers to send valid replies and add a new peer. + h.enablePeerReplies() + _ = h.addPeer() + + // Force a refresh, which should cause our new peer to be dialed and + // queried. We expect them to send a valid block and fulfill our + // request. + h.refreshPeers() + h.assertPeerDialed() + h.assertPeerQueried() + h.assertPeerReplied(blockChan, errChan, true) +} diff --git a/chain/utils_test.go b/chain/utils_test.go new file mode 100644 index 0000000..2c41c79 --- /dev/null +++ b/chain/utils_test.go @@ -0,0 +1,229 @@ +package chain + +import ( + "fmt" + "io" + "math" + "net" + "runtime" + "sync" + "time" + + "github.com/btcsuite/btcd/blockchain" + "github.com/btcsuite/btcd/chaincfg/chainhash" + "github.com/btcsuite/btcd/wire" + "github.com/btcsuite/btcutil" +) + +// conn mocks a network connection by implementing the net.Conn interface. It is +// used to test peer connection without actually opening a network connection. +type conn struct { + io.Reader + io.Writer + io.Closer + localAddr string + remoteAddr string +} + +func (c conn) LocalAddr() net.Addr { + return &addr{"tcp", c.localAddr} +} +func (c conn) RemoteAddr() net.Addr { + return &addr{"tcp", c.remoteAddr} +} +func (c conn) SetDeadline(t time.Time) error { return nil } +func (c conn) SetReadDeadline(t time.Time) error { return nil } +func (c conn) SetWriteDeadline(t time.Time) error { return nil } + +// addr mocks a network address. +type addr struct { + net, address string +} + +func (m addr) Network() string { return m.net } +func (m addr) String() string { return m.address } + +// pipe turns two mock connections into a full-duplex connection similar to +// net.Pipe to allow pipe's with (fake) addresses. +func pipe(c1, c2 *conn) (*conn, *conn) { + r1, w1 := io.Pipe() + r2, w2 := io.Pipe() + + c1.Writer = w1 + c1.Closer = w1 + c2.Reader = r1 + c1.Reader = r2 + c2.Writer = w2 + c2.Closer = w2 + + return c1, c2 +} + +// calcMerkleRoot creates a merkle tree from the slice of transactions and +// returns the root of the tree. +// +// This function was copied from: +// https://github.com/btcsuite/btcd/blob/36a96f6a0025b6aeaebe4106821c2d46ee4be8d4/blockchain/fullblocktests/generate.go#L303 +func calcMerkleRoot(txns []*wire.MsgTx) chainhash.Hash { + if len(txns) == 0 { + return chainhash.Hash{} + } + + utilTxns := make([]*btcutil.Tx, 0, len(txns)) + for _, tx := range txns { + utilTxns = append(utilTxns, btcutil.NewTx(tx)) + } + merkles := blockchain.BuildMerkleTreeStore(utilTxns, false) + return *merkles[len(merkles)-1] +} + +// solveBlock attempts to find a nonce which makes the passed block header hash +// to a value less than the target difficulty. When a successful solution is +// found true is returned and the nonce field of the passed header is updated +// with the solution. False is returned if no solution exists. +// +// This function was copied from: +// https://github.com/btcsuite/btcd/blob/36a96f6a0025b6aeaebe4106821c2d46ee4be8d4/blockchain/fullblocktests/generate.go#L324 +func solveBlock(header *wire.BlockHeader) bool { + // sbResult is used by the solver goroutines to send results. + type sbResult struct { + found bool + nonce uint32 + } + + // Make sure all spawned goroutines finish executing before returning. + var wg sync.WaitGroup + defer func() { + wg.Wait() + }() + + // solver accepts a block header and a nonce range to test. It is + // intended to be run as a goroutine. + targetDifficulty := blockchain.CompactToBig(header.Bits) + quit := make(chan bool) + results := make(chan sbResult) + solver := func(hdr wire.BlockHeader, startNonce, stopNonce uint32) { + defer wg.Done() + + // We need to modify the nonce field of the header, so make sure + // we work with a copy of the original header. + for i := startNonce; i >= startNonce && i <= stopNonce; i++ { + select { + case <-quit: + return + default: + hdr.Nonce = i + hash := hdr.BlockHash() + if blockchain.HashToBig(&hash).Cmp( + targetDifficulty) <= 0 { + + select { + case results <- sbResult{true, i}: + case <-quit: + } + + return + } + } + } + + select { + case results <- sbResult{false, 0}: + case <-quit: + } + } + + startNonce := uint32(1) + stopNonce := uint32(math.MaxUint32) + numCores := uint32(runtime.NumCPU()) + noncesPerCore := (stopNonce - startNonce) / numCores + wg.Add(int(numCores)) + for i := uint32(0); i < numCores; i++ { + rangeStart := startNonce + (noncesPerCore * i) + rangeStop := startNonce + (noncesPerCore * (i + 1)) - 1 + if i == numCores-1 { + rangeStop = stopNonce + } + go solver(*header, rangeStart, rangeStop) + } + for i := uint32(0); i < numCores; i++ { + result := <-results + if result.found { + close(quit) + header.Nonce = result.nonce + return true + } + } + + return false +} + +// genBlockChain generates a test chain with the given number of blocks. +func genBlockChain(numBlocks uint32) ([]*chainhash.Hash, map[chainhash.Hash]*wire.MsgBlock) { + prevHash := chainParams.GenesisHash + prevHeader := &chainParams.GenesisBlock.Header + + hashes := make([]*chainhash.Hash, numBlocks) + blocks := make(map[chainhash.Hash]*wire.MsgBlock, numBlocks) + + // Each block contains three transactions, including the coinbase + // transaction. Each non-coinbase transaction spends outputs from + // the previous block. We also need to produce blocks that succeed + // validation through blockchain.CheckBlockSanity. + script := []byte{0x01, 0x01} + createTx := func(prevOut wire.OutPoint) *wire.MsgTx { + return &wire.MsgTx{ + TxIn: []*wire.TxIn{{ + PreviousOutPoint: prevOut, + SignatureScript: script, + }}, + TxOut: []*wire.TxOut{{PkScript: script}}, + } + } + for i := uint32(0); i < numBlocks; i++ { + txs := []*wire.MsgTx{ + createTx(wire.OutPoint{Index: wire.MaxPrevOutIndex}), + createTx(wire.OutPoint{Hash: *prevHash, Index: 0}), + createTx(wire.OutPoint{Hash: *prevHash, Index: 1}), + } + header := &wire.BlockHeader{ + Version: 1, + PrevBlock: *prevHash, + MerkleRoot: calcMerkleRoot(txs), + Timestamp: prevHeader.Timestamp.Add(10 * time.Minute), + Bits: chainParams.PowLimitBits, + Nonce: 0, + } + if !solveBlock(header) { + panic(fmt.Sprintf("could not solve block at idx %v", i)) + } + block := &wire.MsgBlock{ + Header: *header, + Transactions: txs, + } + + blockHash := block.BlockHash() + hashes[i] = &blockHash + blocks[blockHash] = block + + prevHash = &blockHash + prevHeader = header + } + + return hashes, blocks +} + +// producesInvalidBlock produces a copy of the block that duplicates the last +// transaction. When the block has an odd number of transactions, this results +// in the invalid block maintaining the same hash as the valid block. +func produceInvalidBlock(block *wire.MsgBlock) *wire.MsgBlock { + numTxs := len(block.Transactions) + lastTx := block.Transactions[numTxs-1] + blockCopy := &wire.MsgBlock{ + Header: block.Header, + Transactions: make([]*wire.MsgTx, numTxs), + } + copy(blockCopy.Transactions, block.Transactions) + blockCopy.AddTransaction(lastTx) + return blockCopy +}