diff --git a/server.go b/server.go index 9b5554d..9fe6ad0 100644 --- a/server.go +++ b/server.go @@ -87,54 +87,44 @@ func (s *Server) doError(conn net.Conn, e error) error { } func (s *Server) receiveBlob(conn net.Conn) error { - var sendRequest sendBlobRequest - dec := json.NewDecoder(conn) - err := dec.Decode(&sendRequest) - if err != nil { - return err - } else if sendRequest.BlobSize > BlobSize { - return fmt.Errorf("Blob size cannot be greater than " + strconv.Itoa(BlobSize) + " bytes") - } - - // check if blob exists - haveBlob := false - sendResponse, err := json.Marshal(sendBlobResponse{SendBlob: !haveBlob}) + blobSize, blobHash, err := s.readBlobRequest(conn) if err != nil { return err } - _, err = conn.Write(sendResponse) + blobExists := false + blobPath := path.Join(s.BlobDir, blobHash) + if _, err := os.Stat(blobPath); !os.IsNotExist(err) { + blobExists = true + } + + err = s.sendBlobResponse(conn, blobExists) if err != nil { return err } - blob := make([]byte, sendRequest.BlobSize) + if blobExists { + return nil + } + + blob := make([]byte, blobSize) _, err = io.ReadFull(bufio.NewReader(conn), blob) if err != nil { return err } - blobHash := getBlobHash(blob) - if sendRequest.BlobHash != blobHash { - return fmt.Errorf("Hash of received blob does not match hash from send request") + receivedBlobHash := getBlobHash(blob) + if blobHash != receivedBlobHash { + return fmt.Errorf("Hash of received blob data does not match hash from send request") } log.Println("Got blob " + blobHash[:8]) - err = ioutil.WriteFile(path.Join(s.BlobDir, blobHash), blob, 0644) + err = ioutil.WriteFile(blobPath, blob, 0644) if err != nil { return err } - transferResponse, err := json.Marshal(blobTransferResponse{ReceivedBlob: true}) - if err != nil { - return err - } - - _, err = conn.Write(transferResponse) - if err != nil { - return err - } - return nil + return s.sendTransferResponse(conn, true) } func (s *Server) doHandshake(conn net.Conn) error { @@ -160,6 +150,42 @@ func (s *Server) doHandshake(conn net.Conn) error { return nil } +func (s *Server) readBlobRequest(conn net.Conn) (int, string, error) { + var sendRequest sendBlobRequest + dec := json.NewDecoder(conn) + err := dec.Decode(&sendRequest) + if err != nil { + return 0, "", err + } else if sendRequest.BlobSize > BlobSize { + return 0, "", fmt.Errorf("Blob size cannot be greater than " + strconv.Itoa(BlobSize) + " bytes") + } + return sendRequest.BlobSize, sendRequest.BlobHash, nil +} + +func (s *Server) sendBlobResponse(conn net.Conn, blobExists bool) error { + sendResponse, err := json.Marshal(sendBlobResponse{SendBlob: !blobExists}) + if err != nil { + return err + } + _, err = conn.Write(sendResponse) + if err != nil { + return err + } + return nil +} + +func (s *Server) sendTransferResponse(conn net.Conn, receivedBlob bool) error { + transferResponse, err := json.Marshal(blobTransferResponse{ReceivedBlob: receivedBlob}) + if err != nil { + return err + } + _, err = conn.Write(transferResponse) + if err != nil { + return err + } + return nil +} + func (s *Server) ensureBlobDirExists() error { if stat, err := os.Stat(s.BlobDir); err != nil { if os.IsNotExist(err) {