get rid of ReadDeadline, switch to updated stopOnce

This commit is contained in:
Alex Grintsvayg 2018-05-24 17:49:43 -04:00
parent 57a7c23787
commit 388c1128ec
9 changed files with 43 additions and 58 deletions

View file

@ -6,7 +6,7 @@ import (
"encoding/hex" "encoding/hex"
"strings" "strings"
"github.com/lbryio/errors.go" "github.com/lbryio/lbry.go/errors"
"github.com/lyoshenka/bencode" "github.com/lyoshenka/bencode"
) )

View file

@ -121,8 +121,8 @@ func (b *BootstrapNode) get(limit int) []Contact {
// ping pings a node. if the node responds, it is added to the list. otherwise, it is removed // ping pings a node. if the node responds, it is added to the list. otherwise, it is removed
func (b *BootstrapNode) ping(c Contact) { func (b *BootstrapNode) ping(c Contact) {
b.stopWG.Add(1) b.stop.Add(1)
defer b.stopWG.Done() defer b.stop.Done()
resCh, cancel := b.SendCancelable(c, Request{Method: pingMethod}) resCh, cancel := b.SendCancelable(c, Request{Method: pingMethod})

View file

@ -7,7 +7,7 @@ import (
"sync" "sync"
"time" "time"
"github.com/lbryio/errors.go" "github.com/lbryio/lbry.go/errors"
"github.com/lbryio/lbry.go/stopOnce" "github.com/lbryio/lbry.go/stopOnce"
"github.com/spf13/cast" "github.com/spf13/cast"

View file

@ -7,7 +7,7 @@ import (
"strconv" "strconv"
"strings" "strings"
"github.com/lbryio/errors.go" "github.com/lbryio/lbry.go/errors"
"github.com/lyoshenka/bencode" "github.com/lyoshenka/bencode"
"github.com/spf13/cast" "github.com/spf13/cast"

View file

@ -40,6 +40,8 @@ type Node struct {
id Bitmap id Bitmap
// UDP connection for sending and receiving data // UDP connection for sending and receiving data
conn UDPConn conn UDPConn
// true if we've closed the connection on purpose
connClosed bool
// token manager // token manager
tokens *tokenManager tokens *tokenManager
@ -56,8 +58,7 @@ type Node struct {
requestHandler RequestHandlerFunc requestHandler RequestHandlerFunc
// stop the node neatly and clean up after itself // stop the node neatly and clean up after itself
stop *stopOnce.Stopper stop *stopOnce.Stopper
stopWG *sync.WaitGroup
} }
// New returns a Node pointer. // New returns a Node pointer.
@ -71,7 +72,6 @@ func NewNode(id Bitmap) *Node {
transactions: make(map[messageID]*transaction), transactions: make(map[messageID]*transaction),
stop: stopOnce.New(), stop: stopOnce.New(),
stopWG: &sync.WaitGroup{},
tokens: &tokenManager{}, tokens: &tokenManager{},
} }
} }
@ -80,43 +80,31 @@ func NewNode(id Bitmap) *Node {
func (n *Node) Connect(conn UDPConn) error { func (n *Node) Connect(conn UDPConn) error {
n.conn = conn n.conn = conn
//if dht.conf.PrintState > 0 {
// go func() {
// t := time.NewTicker(dht.conf.PrintState)
// for {
// dht.PrintState()
// select {
// case <-t.C:
// case <-dht.stop.Ch():
// return
// }
// }
// }()
//}
n.tokens.Start(tokenSecretRotationInterval) n.tokens.Start(tokenSecretRotationInterval)
go func() {
// stop tokens and close the connection when we're shutting down
<-n.stop.Ch()
n.tokens.Stop()
n.connClosed = true
n.conn.Close()
}()
packets := make(chan packet) packets := make(chan packet)
go func() { go func() {
n.stopWG.Add(1) n.stop.Add(1)
defer n.stopWG.Done() defer n.stop.Done()
buf := make([]byte, udpMaxMessageLength) buf := make([]byte, udpMaxMessageLength)
for { for {
select {
case <-n.stop.Ch():
return
default:
}
n.conn.SetReadDeadline(time.Now().Add(200 * time.Millisecond)) // need this to periodically check shutdown chan
bytesRead, raddr, err := n.conn.ReadFromUDP(buf) bytesRead, raddr, err := n.conn.ReadFromUDP(buf)
if err != nil { if err != nil {
if e, ok := err.(net.Error); !ok || !e.Timeout() { if n.connClosed {
log.Errorf("udp read error: %v", err) return
} }
log.Errorf("udp read error: %v", err)
continue continue
} else if raddr == nil { } else if raddr == nil {
log.Errorf("udp read with no raddr") log.Errorf("udp read with no raddr")
@ -129,13 +117,14 @@ func (n *Node) Connect(conn UDPConn) error {
select { // needs select here because packet consumer can quit and the packets channel gets filled up and blocks select { // needs select here because packet consumer can quit and the packets channel gets filled up and blocks
case packets <- packet{data: data, raddr: raddr}: case packets <- packet{data: data, raddr: raddr}:
case <-n.stop.Ch(): case <-n.stop.Ch():
return
} }
} }
}() }()
go func() { go func() {
n.stopWG.Add(1) n.stop.Add(1)
defer n.stopWG.Done() defer n.stop.Done()
var pkt packet var pkt packet
@ -157,10 +146,7 @@ func (n *Node) Connect(conn UDPConn) error {
// Shutdown shuts down the node // Shutdown shuts down the node
func (n *Node) Shutdown() { func (n *Node) Shutdown() {
log.Debugf("[%s] node shutting down", n.id.HexShort()) log.Debugf("[%s] node shutting down", n.id.HexShort())
n.stop.Stop() n.stop.StopAndWait()
n.stopWG.Wait()
n.tokens.Stop()
n.conn.Close()
log.Debugf("[%s] node stopped", n.id.HexShort()) log.Debugf("[%s] node stopped", n.id.HexShort())
} }
@ -316,7 +302,7 @@ func (n *Node) sendMessage(addr *net.UDPAddr, data Message) error {
log.Debugf("[%s] (%d bytes) %s", n.id.HexShort(), len(encoded), spew.Sdump(data)) log.Debugf("[%s] (%d bytes) %s", n.id.HexShort(), len(encoded), spew.Sdump(data))
} }
n.conn.SetWriteDeadline(time.Now().Add(time.Second * 15)) n.conn.SetWriteDeadline(time.Now().Add(5 * time.Second))
_, err = n.conn.WriteToUDP(encoded, addr) _, err = n.conn.WriteToUDP(encoded, addr)
return errors.Err(err) return errors.Err(err)
@ -427,9 +413,9 @@ func (n *Node) CountActiveTransactions() int {
} }
func (n *Node) startRoutingTableGrooming() { func (n *Node) startRoutingTableGrooming() {
n.stopWG.Add(1) n.stop.Add(1)
go func() { go func() {
defer n.stopWG.Done() defer n.stop.Done()
refreshTicker := time.NewTicker(tRefresh / 5) // how often to check for buckets that need to be refreshed refreshTicker := time.NewTicker(tRefresh / 5) // how often to check for buckets that need to be refreshed
for { for {
select { select {

View file

@ -5,7 +5,7 @@ import (
"sync" "sync"
"time" "time"
"github.com/lbryio/errors.go" "github.com/lbryio/lbry.go/errors"
"github.com/lbryio/lbry.go/stopOnce" "github.com/lbryio/lbry.go/stopOnce"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"

View file

@ -11,7 +11,7 @@ import (
"sync" "sync"
"time" "time"
"github.com/lbryio/errors.go" "github.com/lbryio/lbry.go/errors"
"github.com/lyoshenka/bencode" "github.com/lyoshenka/bencode"
) )
@ -437,9 +437,8 @@ func (rt *routingTable) UnmarshalJSON(b []byte) error {
// RoutingTableRefresh refreshes any buckets that need to be refreshed // RoutingTableRefresh refreshes any buckets that need to be refreshed
// It returns a channel that will be closed when the refresh is done // It returns a channel that will be closed when the refresh is done
func RoutingTableRefresh(n *Node, refreshInterval time.Duration, cancel <-chan struct{}) <-chan struct{} { func RoutingTableRefresh(n *Node, refreshInterval time.Duration, cancel <-chan struct{}) <-chan struct{} {
done := make(chan struct{})
var wg sync.WaitGroup var wg sync.WaitGroup
done := make(chan struct{})
for _, id := range n.rt.GetIDsForRefresh(refreshInterval) { for _, id := range n.rt.GetIDsForRefresh(refreshInterval) {
wg.Add(1) wg.Add(1)

View file

@ -7,7 +7,7 @@ import (
"testing" "testing"
"time" "time"
"github.com/lbryio/errors.go" "github.com/lbryio/lbry.go/errors"
) )
var testingDHTIP = "127.0.0.1" var testingDHTIP = "127.0.0.1"
@ -107,7 +107,10 @@ func (t testUDPConn) ReadFromUDP(b []byte) (int, *net.UDPAddr, error) {
} }
select { select {
case packet := <-t.toRead: case packet, ok := <-t.toRead:
if !ok {
return 0, nil, errors.Err("conn closed")
}
n := copy(b, packet.data) n := copy(b, packet.data)
return n, packet.addr, nil return n, packet.addr, nil
case <-timeoutCh: case <-timeoutCh:
@ -130,7 +133,7 @@ func (t *testUDPConn) SetWriteDeadline(tm time.Time) error {
} }
func (t *testUDPConn) Close() error { func (t *testUDPConn) Close() error {
t.toRead = nil close(t.toRead)
t.writes = nil t.writes = nil
return nil return nil
} }

View file

@ -16,28 +16,26 @@ type tokenManager struct {
secret []byte secret []byte
prevSecret []byte prevSecret []byte
lock *sync.RWMutex lock *sync.RWMutex
wg *sync.WaitGroup stop *stopOnce.Stopper
done *stopOnce.Stopper
} }
func (tm *tokenManager) Start(interval time.Duration) { func (tm *tokenManager) Start(interval time.Duration) {
tm.secret = make([]byte, 64) tm.secret = make([]byte, 64)
tm.prevSecret = make([]byte, 64) tm.prevSecret = make([]byte, 64)
tm.lock = &sync.RWMutex{} tm.lock = &sync.RWMutex{}
tm.wg = &sync.WaitGroup{} tm.stop = stopOnce.New()
tm.done = stopOnce.New()
tm.rotateSecret() tm.rotateSecret()
tm.wg.Add(1) tm.stop.Add(1)
go func() { go func() {
defer tm.wg.Done() defer tm.stop.Done()
tick := time.NewTicker(interval) tick := time.NewTicker(interval)
for { for {
select { select {
case <-tick.C: case <-tick.C:
tm.rotateSecret() tm.rotateSecret()
case <-tm.done.Ch(): case <-tm.stop.Ch():
return return
} }
} }
@ -45,8 +43,7 @@ func (tm *tokenManager) Start(interval time.Duration) {
} }
func (tm *tokenManager) Stop() { func (tm *tokenManager) Stop() {
tm.done.Stop() tm.stop.StopAndWait()
tm.wg.Wait()
} }
func (tm *tokenManager) Get(nodeID Bitmap, addr *net.UDPAddr) string { func (tm *tokenManager) Get(nodeID Bitmap, addr *net.UDPAddr) string {