diff --git a/dht/message.go b/dht/message.go index e78ddd6..d10f1d5 100644 --- a/dht/message.go +++ b/dht/message.go @@ -41,6 +41,7 @@ const ( headerArgsField = "4" contactsField = "contacts" tokenField = "token" + protocolVersionField = "protocolVersion" ) // Message is an extension of the bencode marshalling interface for serialized message passing. @@ -83,11 +84,12 @@ func newMessageID() messageID { // Request represents the structured request from one node to another. type Request struct { - ID messageID - NodeID bits.Bitmap - Method string - Arg *bits.Bitmap - StoreArgs *storeArgs + ID messageID + NodeID bits.Bitmap + Method string + Arg *bits.Bitmap + StoreArgs *storeArgs + ProtocolVersion int } // MarshalBencode returns the serialized byte slice representation of the request @@ -133,17 +135,47 @@ func (r *Request) UnmarshalBencode(b []byte) error { return errors.Prefix("request unmarshal", err) } } else if len(raw.Args) > 2 { // 2 because an empty list is `le` - tmp := []bits.Bitmap{} - err = bencode.DecodeBytes(raw.Args, &tmp) + r.Arg, r.ProtocolVersion, err = processArgsAndProtoVersion(raw.Args) if err != nil { return errors.Prefix("request unmarshal", err) } - r.Arg = &tmp[0] } return nil } +func processArgsAndProtoVersion(raw bencode.RawMessage) (arg *bits.Bitmap, version int, err error) { + var args []bencode.RawMessage + err = bencode.DecodeBytes(raw, &args) + if err != nil { + return nil, 0, err + } + + if len(args) == 0 { + return nil, 0, nil + } + + var extras map[string]int + err = bencode.DecodeBytes(args[len(args)-1], &extras) + if err == nil { + if v, exists := extras[protocolVersionField]; exists { + version = v + args = args[:len(args)-1] + } + } + + if len(args) > 0 { + var b bits.Bitmap + err = bencode.DecodeBytes(args[0], &b) + if err != nil { + return nil, 0, err + } + arg = &b + } + + return arg, version, nil +} + func (r Request) argsDebug() string { if r.StoreArgs != nil { return r.StoreArgs.BlobHash.HexShort() + ", " + r.StoreArgs.Value.LbryID.HexShort() + ":" + strconv.Itoa(r.StoreArgs.Value.Port) @@ -231,12 +263,13 @@ func (s *storeArgs) UnmarshalBencode(b []byte) error { // Response represents the structured response one node returns to another. type Response struct { - ID messageID - NodeID bits.Bitmap - Data string - Contacts []Contact - FindValueKey string - Token string + ID messageID + NodeID bits.Bitmap + Data string + Contacts []Contact + FindValueKey string + Token string + ProtocolVersion int } func (r Response) argsDebug() string { @@ -251,7 +284,7 @@ func (r Response) argsDebug() string { str += "|" for _, c := range r.Contacts { - str += c.Addr().String() + ":" + c.ID.HexShort() + "," + str += c.String() + "," } str = strings.TrimRight(str, ",") + "|" @@ -344,7 +377,15 @@ func (r *Response) UnmarshalBencode(b []byte) error { if err != nil { return err } - delete(rawData, tokenField) // it doesnt mess up findValue key finding below + delete(rawData, tokenField) // so it doesnt mess up findValue key finding below + } + + if protocolVersion, ok := rawData[protocolVersionField]; ok { + err = bencode.DecodeBytes(protocolVersion, &r.ProtocolVersion) + if err != nil { + return err + } + delete(rawData, protocolVersionField) // so it doesnt mess up findValue key finding below } if contacts, ok := rawData[contactsField]; ok { diff --git a/dht/message_test.go b/dht/message_test.go index c42cdef..71812eb 100644 --- a/dht/message_test.go +++ b/dht/message_test.go @@ -167,6 +167,40 @@ func TestDecodeFindNodeResponseWithNoNodes(t *testing.T) { } } +func TestDecodeRequestWithProtocolVersion(t *testing.T) { + raw, err := hex.DecodeString("6469306569306569316532303a65e1b9afce87c44abc40b4bb466890d2797f0dd269326534383a1baf3dbba8746c7a739f35465e268ace823622c5e8ec1dd7d5d27af5795cfbc54f22a32fbd05d420f241266b3dd16831693365343a70696e676934656c6431353a70726f746f636f6c56657273696f6e693165656565") + if err != nil { + t.Fatal(err) + } + + req := Request{} + err = bencode.DecodeBytes(raw, &req) + if err != nil { + t.Fatal(err) + } + + if req.ProtocolVersion != 1 { + t.Error("protocol version was not detected correctly") + } +} + +func TestDecodeResponseWithProtocolVersion(t *testing.T) { + raw, err := hex.DecodeString("6469306569316569316532303a2b96f5de8be1e86dc2332a50eb313c97848064db69326534383a8b8eb692658ea3d7e7828e80a3133d524c6f82aaff370efa759d5b87821035a32de06724cb099f01a819695f829dff7f6933656431353a70726f746f636f6c56657273696f6e693165353a746f6b656e34383af2368fe4fd06ede6631ad8b153f3b5d8db724f8d520f4291e992e206dd02a216fb7dfd9b81686f11b626e3840df65fb434383a89c5c3f9794b0b24a03406e3b74361edb9ae70828e4c133512fc75db0a2d312673cdd4e30eed37892a46692d2fe439f36c35343a12dd65af0d058cd7d10d122fbe2eb5ae31062b7480011be588f20cfe6807b1939c42eea639059fa6365bfccb56ef8c9e574f49ba035c656565") + if err != nil { + t.Fatal(err) + } + + res := Response{} + err = bencode.DecodeBytes(raw, &res) + if err != nil { + t.Fatal(err) + } + + if res.ProtocolVersion != 1 { + t.Error("protocol version was not detected correctly") + } +} + func compareResponses(t *testing.T, res, res2 Response) { if res.ID != res2.ID { t.Errorf("expected ID %s, got %s", res.ID, res2.ID)