diff --git a/bittorrent/client_id.go b/bittorrent/client_id.go index b50be80..840fd75 100644 --- a/bittorrent/client_id.go +++ b/bittorrent/client_id.go @@ -2,21 +2,21 @@ package bittorrent // ClientID represents the part of a PeerID that identifies a Peer's client // software. -type ClientID string +type ClientID [6]byte // NewClientID parses a ClientID from a PeerID. -func NewClientID(peerID string) ClientID { - var clientID string - length := len(peerID) +func NewClientID(pid PeerID) ClientID { + var cid ClientID + length := len(pid) if length >= 6 { - if peerID[0] == '-' { + if pid[0] == '-' { if length >= 7 { - clientID = peerID[1:7] + copy(cid[:], pid[1:7]) } } else { - clientID = peerID[:6] + copy(cid[:], pid[:6]) } } - return ClientID(clientID) + return cid } diff --git a/bittorrent/client_id_test.go b/bittorrent/client_id_test.go index 126f701..ce760fa 100644 --- a/bittorrent/client_id_test.go +++ b/bittorrent/client_id_test.go @@ -1,6 +1,9 @@ package bittorrent -import "testing" +import ( + "bytes" + "testing" +) func TestClientID(t *testing.T) { var clientTable = []struct{ peerID, clientID string }{ @@ -38,17 +41,12 @@ func TestClientID(t *testing.T) { {"Q1-10-0-Yoiumn39BDfO", "Q1-10-"}, // Queen Bee Alt {"346------SDFknl33408", "346---"}, // TorreTopia {"QVOD0054ABFFEDCCDEDB", "QVOD00"}, // Qvod - - {"", ""}, - {"-", ""}, - {"12345", ""}, - {"-12345", ""}, - {"123456", "123456"}, - {"-123456", "123456"}, } for _, tt := range clientTable { - if parsedID := NewClientID(tt.peerID); parsedID != ClientID(tt.clientID) { + clientID := ClientID([]byte(tt.clientID)) + parsedID := NewClientID(PeerIDFromBytes([]byte(tt.peerID))) + if !bytes.Equal([]byte(parsedID), []byte(clientID)) { t.Error("Incorrectly parsed peer ID", tt.peerID, "as", parsedID) } } diff --git a/middleware/clientwhitelist/clientwhitelist.go b/middleware/clientwhitelist/clientwhitelist.go index f7aefd4..3443e05 100644 --- a/middleware/clientwhitelist/clientwhitelist.go +++ b/middleware/clientwhitelist/clientwhitelist.go @@ -4,6 +4,7 @@ package clientwhitelist import ( "context" + "errors" "github.com/chihaya/chihaya/bittorrent" "github.com/chihaya/chihaya/middleware" @@ -17,16 +18,22 @@ type hook struct { approved map[bittorrent.ClientID]struct{} } -func NewHook(approved []string) middleware.Hook { +func NewHook(approved []string) (middleware.Hook, error) { h := &hook{ approved: make(map[bittorrent.ClientID]struct{}), } - for _, clientID := range approved { - h.approved[bittorrent.NewClientID(clientID)] = struct{}{} + for _, cidString := range approved { + cidBytes := []byte(cidString) + if len(cidBytes) != 6 { + return nil, errors.New("clientID " + cidString + " must be 6 bytes") + } + var cid bittorrent.ClientID + copy(cid[:], cidBytes) + h.approved[cid] = struct{}{} } - return h + return h, nil } func (h *hook) HandleAnnounce(ctx context.Context, req *bittorrent.AnnounceRequest, resp *bittorrent.AnnounceResponse) error {