diff --git a/client.go b/client.go index f269770..341d7f6 100644 --- a/client.go +++ b/client.go @@ -9,7 +9,8 @@ import ( ) type Client struct { - conn net.Conn + conn net.Conn + connected bool } func (c *Client) Connect(address string) error { @@ -18,13 +19,19 @@ func (c *Client) Connect(address string) error { if err != nil { return err } + c.connected = true return c.doHandshake(protocolVersion1) } func (c *Client) Close() error { + c.connected = false return c.conn.Close() } func (c *Client) SendBlob(blob []byte) error { + if !c.connected { + return fmt.Errorf("Not connected") + } + if len(blob) != BlobSize { return fmt.Errorf("Blob must be exactly " + strconv.Itoa(BlobSize) + " bytes") } @@ -37,7 +44,6 @@ func (c *Client) SendBlob(blob []byte) error { if err != nil { return err } - _, err = c.conn.Write(sendRequest) if err != nil { return err @@ -75,6 +81,10 @@ func (c *Client) SendBlob(blob []byte) error { } func (c *Client) doHandshake(version int) error { + if !c.connected { + return fmt.Errorf("Not connected") + } + handshake, err := json.Marshal(handshakeRequestResponse{Version: version}) if err != nil { return err diff --git a/client_test.go b/client_test.go new file mode 100644 index 0000000..c43d50a --- /dev/null +++ b/client_test.go @@ -0,0 +1,60 @@ +package main + +import ( + "io/ioutil" + "math/rand" + "os" + "strconv" + "testing" + "time" +) + +var address = "localhost:" + strconv.Itoa(DefaultPort) +var s Server + +func TestMain(m *testing.M) { + rand.Seed(time.Now().UnixNano()) + + dir, err := ioutil.TempDir("", "reflector_client_test") + if err != nil { + panic(err) + } + defer os.RemoveAll(dir) + + s := NewServer(dir) + go s.ListenAndServe(address) + + os.Exit(m.Run()) +} + +func TestNotConnected(t *testing.T) { + c := Client{} + err := c.SendBlob([]byte{}) + if err == nil { + t.Error("client should error if it is not connected") + } +} + +func TestSmallBlob(t *testing.T) { + c := Client{} + err := c.Connect(address) + if err != nil { + t.Error(err) + } + + err = c.SendBlob([]byte{}) + if err == nil { + t.Error("client should error if blob is empty") + } + + blob := make([]byte, 1000) + _, err = rand.Read(blob) + if err != nil { + t.Error("failed to make random blob") + } + + err = c.SendBlob([]byte{}) + if err == nil { + t.Error("client should error if blob is the wrong size") + } +} diff --git a/main.go b/main.go index deb168c..db1ae62 100644 --- a/main.go +++ b/main.go @@ -4,6 +4,7 @@ import ( "flag" "log" "math/rand" + "strconv" "time" ) @@ -17,7 +18,8 @@ func main() { var err error rand.Seed(time.Now().UnixNano()) - address := "localhost:5566" + port := DefaultPort + address := "52.14.109.125:" + strconv.Itoa(port) serve := flag.Bool("server", false, "Run server") blobDir := flag.String("blobdir", "", "Where blobs will be saved to") @@ -49,16 +51,4 @@ func main() { checkErr(err) err = client.SendBlob(blob) checkErr(err) - - blob = make([]byte, 2*1024*1024) - _, err = rand.Read(blob) - checkErr(err) - err = client.SendBlob(blob) - checkErr(err) - - blob = make([]byte, 2*1024*1024) - _, err = rand.Read(blob) - checkErr(err) - err = client.SendBlob(blob) - checkErr(err) } diff --git a/server.go b/server.go index 9fe6ad0..344c634 100644 --- a/server.go +++ b/server.go @@ -87,18 +87,20 @@ func (s *Server) doError(conn net.Conn, e error) error { } func (s *Server) receiveBlob(conn net.Conn) error { - blobSize, blobHash, err := s.readBlobRequest(conn) + blobSize, blobHash, isSdBlob, err := s.readBlobRequest(conn) if err != nil { return err } blobExists := false blobPath := path.Join(s.BlobDir, blobHash) - if _, err := os.Stat(blobPath); !os.IsNotExist(err) { - blobExists = true + if !isSdBlob { // we have to say sd blobs are missing because if we say we have it, they wont try to send any content blobs + if _, err := os.Stat(blobPath); !os.IsNotExist(err) { + blobExists = true + } } - err = s.sendBlobResponse(conn, blobExists) + err = s.sendBlobResponse(conn, blobExists, isSdBlob) if err != nil { return err } @@ -116,6 +118,7 @@ func (s *Server) receiveBlob(conn net.Conn) error { receivedBlobHash := getBlobHash(blob) if blobHash != receivedBlobHash { return fmt.Errorf("Hash of received blob data does not match hash from send request") + // this can also happen if the blob size is wrong, because the server will read the wrong number of bytes from the stream } log.Println("Got blob " + blobHash[:8]) @@ -124,7 +127,7 @@ func (s *Server) receiveBlob(conn net.Conn) error { return err } - return s.sendTransferResponse(conn, true) + return s.sendTransferResponse(conn, true, isSdBlob) } func (s *Server) doHandshake(conn net.Conn) error { @@ -133,8 +136,8 @@ func (s *Server) doHandshake(conn net.Conn) error { err := dec.Decode(&handshake) if err != nil { return err - } else if handshake.Version != protocolVersion1 { - return fmt.Errorf("This server only supports protocol version 1") + } else if handshake.Version != protocolVersion1 && handshake.Version != protocolVersion2 { + return fmt.Errorf("Protocol version not supported") } resp, err := json.Marshal(handshakeRequestResponse{Version: handshake.Version}) @@ -150,36 +153,74 @@ func (s *Server) doHandshake(conn net.Conn) error { return nil } -func (s *Server) readBlobRequest(conn net.Conn) (int, string, error) { +func (s *Server) readBlobRequest(conn net.Conn) (int, string, bool, error) { var sendRequest sendBlobRequest dec := json.NewDecoder(conn) err := dec.Decode(&sendRequest) if err != nil { - return 0, "", err - } else if sendRequest.BlobSize > BlobSize { - return 0, "", fmt.Errorf("Blob size cannot be greater than " + strconv.Itoa(BlobSize) + " bytes") + return 0, "", false, err } - return sendRequest.BlobSize, sendRequest.BlobHash, nil + + if sendRequest.SdBlobHash != "" && sendRequest.BlobHash != "" { + return 0, "", false, fmt.Errorf("Invalid request") + } + + var blobHash string + var blobSize int + isSdBlob := sendRequest.SdBlobHash != "" + + if isSdBlob { + blobSize = sendRequest.SdBlobSize + blobHash = sendRequest.SdBlobHash + if blobSize > BlobSize { + return 0, "", isSdBlob, fmt.Errorf("SD blob cannot be more than " + strconv.Itoa(BlobSize) + " bytes") + } + } else { + blobSize = sendRequest.BlobSize + blobHash = sendRequest.BlobHash + if blobSize != BlobSize { + return 0, "", isSdBlob, fmt.Errorf("Blob must be exactly " + strconv.Itoa(BlobSize) + " bytes") + } + } + + return blobSize, blobHash, isSdBlob, nil } -func (s *Server) sendBlobResponse(conn net.Conn, blobExists bool) error { - sendResponse, err := json.Marshal(sendBlobResponse{SendBlob: !blobExists}) +func (s *Server) sendBlobResponse(conn net.Conn, blobExists, isSdBlob bool) error { + var response []byte + var err error + + if isSdBlob { + response, err = json.Marshal(sendSdBlobResponse{SendSdBlob: !blobExists}) + } else { + response, err = json.Marshal(sendBlobResponse{SendBlob: !blobExists}) + } if err != nil { return err } - _, err = conn.Write(sendResponse) + + _, err = conn.Write(response) if err != nil { return err } return nil } -func (s *Server) sendTransferResponse(conn net.Conn, receivedBlob bool) error { - transferResponse, err := json.Marshal(blobTransferResponse{ReceivedBlob: receivedBlob}) +func (s *Server) sendTransferResponse(conn net.Conn, receivedBlob, isSdBlob bool) error { + var response []byte + var err error + + if isSdBlob { + response, err = json.Marshal(sdBlobTransferResponse{ReceivedSdBlob: receivedBlob}) + + } else { + response, err = json.Marshal(blobTransferResponse{ReceivedBlob: receivedBlob}) + } if err != nil { return err } - _, err = conn.Write(transferResponse) + + _, err = conn.Write(response) if err != nil { return err } diff --git a/shared.go b/shared.go index 689eeab..460667a 100644 --- a/shared.go +++ b/shared.go @@ -12,7 +12,7 @@ const ( BlobSize = 2 * 1024 * 1024 protocolVersion1 = 1 - protocolVersion2 = 2 // not implemented + protocolVersion2 = 2 ) var ErrBlobExists = fmt.Errorf("Blob exists on server") @@ -26,18 +26,29 @@ type handshakeRequestResponse struct { } type sendBlobRequest struct { - BlobHash string `json:"blob_hash"` - BlobSize int `json:"blob_size"` + BlobHash string `json:"blob_hash,omitempty"` + BlobSize int `json:"blob_size,omitempty"` + SdBlobHash string `json:"sd_blob_hash,omitempty"` + SdBlobSize int `json:"sd_blob_size,omitempty"` } type sendBlobResponse struct { SendBlob bool `json:"send_blob"` } +type sendSdBlobResponse struct { + SendSdBlob bool `json:"send_sd_blob"` + NeededBlobs []string `json:"needed_blobs,omitempty"` +} + type blobTransferResponse struct { ReceivedBlob bool `json:"received_blob"` } +type sdBlobTransferResponse struct { + ReceivedSdBlob bool `json:"received_sd_blob"` +} + func getBlobHash(blob []byte) string { hashBytes := sha512.Sum384(blob) return hex.EncodeToString(hashBytes[:])