diff --git a/udp/protocol.go b/udp/protocol.go index e0f3a89..4092d93 100644 --- a/udp/protocol.go +++ b/udp/protocol.go @@ -58,6 +58,17 @@ func (s *Server) handlePacket(packet []byte, addr *net.UDPAddr) (response []byte transactionID: transactionID, } + defer func() { + if writer.buf.Len() > 0 { + response = writer.buf.Bytes() + } + }() + + if action != 0 && !bytes.Equal(connID, generatedConnID) { + writer.WriteError(errBadConnectionID) + return + } + switch action { case 0: // Connect request. @@ -71,36 +82,25 @@ func (s *Server) handlePacket(packet []byte, addr *net.UDPAddr) (response []byte case 1: // Announce request. - if !bytes.Equal(connID, generatedConnID) { - writer.WriteError(errBadConnectionID) - } else { - ann, err := s.newAnnounce(packet, addr.IP) + ann, err := s.newAnnounce(packet, addr.IP) - if err == nil { - err = s.tracker.HandleAnnounce(ann, writer) - } - - handleTorrentError(err, writer) + if err == nil { + err = s.tracker.HandleAnnounce(ann, writer) } + handleTorrentError(err, writer) + case 2: // Scrape request. - if !bytes.Equal(connID, generatedConnID) { - writer.WriteError(errBadConnectionID) - } else { - scrape, err := s.newScrape(packet) + scrape, err := s.newScrape(packet) - if err == nil { - err = s.tracker.HandleScrape(scrape, writer) - } - - handleTorrentError(err, writer) + if err == nil { + err = s.tracker.HandleScrape(scrape, writer) } + + handleTorrentError(err, writer) } - if writer.buf.Len() > 0 { - response = writer.buf.Bytes() - } return } @@ -166,6 +166,7 @@ func (s *Server) newScrape(packet []byte) (*models.Scrape, error) { for len(packet) >= 20 { infohash := packet[:20] infohashes = append(infohashes, string(infohash)) + packet = packet[20:] } return &models.Scrape{