mirror of
https://github.com/LBRYFoundation/reflector.go.git
synced 2025-08-23 17:27:25 +00:00
correct peer protocol errors, add simple db store
This commit is contained in:
parent
5592f00c11
commit
8c67da1852
8 changed files with 441 additions and 257 deletions
86
db/db.go
Normal file
86
db/db.go
Normal file
|
@ -0,0 +1,86 @@
|
|||
package db
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
|
||||
"github.com/lbryio/errors.go"
|
||||
qtools "github.com/lbryio/query.go"
|
||||
|
||||
_ "github.com/go-sql-driver/mysql"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
type DB interface {
|
||||
Connect(string) error
|
||||
AddBlob(string, int) error
|
||||
HasBlob(string) (bool, error)
|
||||
}
|
||||
|
||||
type SQL struct {
|
||||
conn *sql.DB
|
||||
}
|
||||
|
||||
func logQuery(query string, args ...interface{}) {
|
||||
s, err := qtools.InterpolateParams(query, args...)
|
||||
if err != nil {
|
||||
log.Errorln(err)
|
||||
} else {
|
||||
log.Debugln(s)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *SQL) Connect(dsn string) error {
|
||||
var err error
|
||||
dsn += "?parseTime=1&collation=utf8mb4_unicode_ci"
|
||||
s.conn, err = sql.Open("mysql", dsn)
|
||||
if err != nil {
|
||||
return errors.Err(err)
|
||||
}
|
||||
|
||||
return errors.Err(s.conn.Ping())
|
||||
}
|
||||
|
||||
func (s *SQL) AddBlob(hash string, length int) error {
|
||||
if s.conn == nil {
|
||||
return errors.Err("not connected")
|
||||
}
|
||||
|
||||
if length <= 0 {
|
||||
return errors.Err("length must be positive")
|
||||
}
|
||||
|
||||
query := "INSERT IGNORE INTO blobs (hash, length) VALUES (?,?)"
|
||||
args := []interface{}{hash, length}
|
||||
|
||||
logQuery(query, args...)
|
||||
|
||||
stmt, err := s.conn.Prepare(query)
|
||||
if err != nil {
|
||||
return errors.Err(err)
|
||||
}
|
||||
|
||||
_, err = stmt.Exec(args...)
|
||||
if err != nil {
|
||||
return errors.Err(err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *SQL) HasBlob(hash string) (bool, error) {
|
||||
if s.conn == nil {
|
||||
return false, errors.Err("not connected")
|
||||
}
|
||||
|
||||
query := "SELECT EXISTS(SELECT 1 FROM blobs WHERE hash = ?)"
|
||||
args := []interface{}{hash}
|
||||
|
||||
logQuery(query, args...)
|
||||
|
||||
row := s.conn.QueryRow(query, args...)
|
||||
|
||||
exists := false
|
||||
err := row.Scan(&exists)
|
||||
|
||||
return exists, errors.Err(err)
|
||||
}
|
38
main.go
38
main.go
|
@ -5,10 +5,13 @@ import (
|
|||
"flag"
|
||||
"io/ioutil"
|
||||
"math/rand"
|
||||
"os"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/lbryio/reflector.go/db"
|
||||
"github.com/lbryio/reflector.go/peer"
|
||||
"github.com/lbryio/reflector.go/reflector"
|
||||
"github.com/lbryio/reflector.go/store"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
@ -22,22 +25,38 @@ func checkErr(err error) {
|
|||
|
||||
func main() {
|
||||
rand.Seed(time.Now().UnixNano())
|
||||
log.SetLevel(log.DebugLevel)
|
||||
|
||||
confFile := flag.String("conf", "config.json", "Config file")
|
||||
flag.Parse()
|
||||
|
||||
conf := loadConfig(*confFile)
|
||||
|
||||
peerAddress := "localhost:" + strconv.Itoa(peer.DefaultPort)
|
||||
server := peer.NewServer(store.NewS3BlobStore(conf.AwsID, conf.AwsSecret, conf.BucketRegion, conf.BucketName))
|
||||
log.Fatal(server.ListenAndServe(peerAddress))
|
||||
return
|
||||
db := new(db.SQL)
|
||||
err := db.Connect(conf.DBConn)
|
||||
checkErr(err)
|
||||
|
||||
//
|
||||
//address := "52.14.109.125:" + strconv.Itoa(port)
|
||||
//reflectorAddress := "localhost:" + strconv.Itoa(reflector.DefaultPort)
|
||||
//server := reflector.NewServer(store.NewS3BlobStore(conf.awsID, conf.awsSecret, conf.bucketRegion, conf.bucketName))
|
||||
//log.Fatal(server.ListenAndServe(reflectorAddress))
|
||||
s3 := store.NewS3BlobStore(conf.AwsID, conf.AwsSecret, conf.BucketRegion, conf.BucketName)
|
||||
|
||||
combo := store.NewDBBackedS3Store(s3, db)
|
||||
|
||||
serverType := ""
|
||||
if len(os.Args) > 1 {
|
||||
serverType = os.Args[1]
|
||||
}
|
||||
|
||||
switch serverType {
|
||||
case "reflector":
|
||||
reflectorAddress := "localhost:" + strconv.Itoa(reflector.DefaultPort)
|
||||
server := reflector.NewServer(combo)
|
||||
log.Fatal(server.ListenAndServe(reflectorAddress))
|
||||
case "peer":
|
||||
peerAddress := "localhost:" + strconv.Itoa(peer.DefaultPort)
|
||||
server := peer.NewServer(combo)
|
||||
log.Fatal(server.ListenAndServe(peerAddress))
|
||||
default:
|
||||
log.Fatal("invalid server type")
|
||||
}
|
||||
|
||||
//
|
||||
//var err error
|
||||
|
@ -66,6 +85,7 @@ type config struct {
|
|||
AwsSecret string `json:"aws_secret"`
|
||||
BucketRegion string `json:"bucket_region"`
|
||||
BucketName string `json:"bucket_name"`
|
||||
DBConn string `json:"db_conn"`
|
||||
}
|
||||
|
||||
func loadConfig(path string) config {
|
||||
|
|
172
peer/server.go
172
peer/server.go
|
@ -1,12 +1,16 @@
|
|||
package peer
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"crypto/sha512"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/davecgh/go-spew/spew"
|
||||
"github.com/lbryio/reflector.go/store"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
@ -39,43 +43,66 @@ func (s *Server) ListenAndServe(address string) error {
|
|||
if err != nil {
|
||||
log.Error(err)
|
||||
} else {
|
||||
go s.handleConn(conn)
|
||||
go s.handleConnection(conn)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) handleConn(conn net.Conn) {
|
||||
// TODO: connection should time out eventually
|
||||
func (s *Server) handleConnection(conn net.Conn) {
|
||||
defer conn.Close()
|
||||
|
||||
err := s.doAvailabilityRequest(conn)
|
||||
if err != nil {
|
||||
log.Error(err)
|
||||
return
|
||||
}
|
||||
|
||||
err = s.doPaymentRateNegotiation(conn)
|
||||
if err != nil {
|
||||
log.Error(err)
|
||||
return
|
||||
}
|
||||
timeoutDuration := 5 * time.Second
|
||||
|
||||
for {
|
||||
err = s.doBlobRequest(conn)
|
||||
var request []byte
|
||||
var response []byte
|
||||
var err error
|
||||
|
||||
conn.SetReadDeadline(time.Now().Add(timeoutDuration))
|
||||
request, err = readNextRequest(conn)
|
||||
if err != nil {
|
||||
if err != io.EOF {
|
||||
log.Error(err)
|
||||
log.Errorln(err)
|
||||
}
|
||||
return
|
||||
}
|
||||
conn.SetReadDeadline(time.Time{})
|
||||
|
||||
if strings.Contains(string(request), `"requested_blobs"`) {
|
||||
log.Debugln("received availability request")
|
||||
response, err = s.handleAvailabilityRequest(request)
|
||||
} else if strings.Contains(string(request), `"blob_data_payment_rate"`) {
|
||||
log.Debugln("received rate negotiation request")
|
||||
response, err = s.handlePaymentRateNegotiation(request)
|
||||
} else if strings.Contains(string(request), `"requested_blob"`) {
|
||||
log.Debugln("received blob request")
|
||||
response, err = s.handleBlobRequest(request)
|
||||
} else {
|
||||
log.Errorln("invalid request")
|
||||
spew.Dump(request)
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
log.Error(err)
|
||||
return
|
||||
}
|
||||
|
||||
n, err := conn.Write(response)
|
||||
if err != nil {
|
||||
log.Errorln(err)
|
||||
return
|
||||
} else if n != len(response) {
|
||||
log.Errorln(io.ErrShortWrite)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) doAvailabilityRequest(conn net.Conn) error {
|
||||
func (s *Server) handleAvailabilityRequest(data []byte) ([]byte, error) {
|
||||
var request availabilityRequest
|
||||
err := json.NewDecoder(conn).Decode(&request)
|
||||
err := json.Unmarshal(data, &request)
|
||||
if err != nil {
|
||||
return err
|
||||
return []byte{}, err
|
||||
}
|
||||
|
||||
address := "bJxKvpD96kaJLriqVajZ7SaQTsWWyrGQct"
|
||||
|
@ -83,31 +110,21 @@ func (s *Server) doAvailabilityRequest(conn net.Conn) error {
|
|||
for _, blobHash := range request.RequestedBlobs {
|
||||
exists, err := s.store.Has(blobHash)
|
||||
if err != nil {
|
||||
return err
|
||||
return []byte{}, err
|
||||
}
|
||||
if exists {
|
||||
availableBlobs = append(availableBlobs, blobHash)
|
||||
}
|
||||
}
|
||||
|
||||
response, err := json.Marshal(availabilityResponse{LbrycrdAddress: address, AvailableBlobs: availableBlobs})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, err = conn.Write(response)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
return json.Marshal(availabilityResponse{LbrycrdAddress: address, AvailableBlobs: availableBlobs})
|
||||
}
|
||||
|
||||
func (s *Server) doPaymentRateNegotiation(conn net.Conn) error {
|
||||
func (s *Server) handlePaymentRateNegotiation(data []byte) ([]byte, error) {
|
||||
var request paymentRateRequest
|
||||
err := json.NewDecoder(conn).Decode(&request)
|
||||
err := json.Unmarshal(data, &request)
|
||||
if err != nil {
|
||||
return err
|
||||
return []byte{}, err
|
||||
}
|
||||
|
||||
offerReply := paymentRateAccepted
|
||||
|
@ -115,31 +132,21 @@ func (s *Server) doPaymentRateNegotiation(conn net.Conn) error {
|
|||
offerReply = paymentRateTooLow
|
||||
}
|
||||
|
||||
response, err := json.Marshal(paymentRateResponse{BlobDataPaymentRate: offerReply})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, err = conn.Write(response)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
return json.Marshal(paymentRateResponse{BlobDataPaymentRate: offerReply})
|
||||
}
|
||||
|
||||
func (s *Server) doBlobRequest(conn net.Conn) error {
|
||||
func (s *Server) handleBlobRequest(data []byte) ([]byte, error) {
|
||||
var request blobRequest
|
||||
err := json.NewDecoder(conn).Decode(&request)
|
||||
err := json.Unmarshal(data, &request)
|
||||
if err != nil {
|
||||
return err
|
||||
return []byte{}, err
|
||||
}
|
||||
|
||||
log.Println("Sending blob " + request.RequestedBlob[:8])
|
||||
|
||||
blob, err := s.store.Get(request.RequestedBlob)
|
||||
if err != nil {
|
||||
return err
|
||||
return []byte{}, err
|
||||
}
|
||||
|
||||
response, err := json.Marshal(blobResponse{IncomingBlob: incomingBlob{
|
||||
|
@ -147,40 +154,63 @@ func (s *Server) doBlobRequest(conn net.Conn) error {
|
|||
Length: len(blob),
|
||||
}})
|
||||
if err != nil {
|
||||
return err
|
||||
return []byte{}, err
|
||||
}
|
||||
|
||||
_, err = conn.Write(response)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, err = conn.Write(blob)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
return append(response, blob...), nil
|
||||
}
|
||||
|
||||
func readAll(conn net.Conn) {
|
||||
buf := make([]byte, 0, 4096) // big buffer
|
||||
tmp := make([]byte, 256) // using small tmo buffer for demonstrating
|
||||
func readNextRequest(conn net.Conn) ([]byte, error) {
|
||||
request := make([]byte, 0)
|
||||
eof := false
|
||||
buf := bufio.NewReader(conn)
|
||||
|
||||
for {
|
||||
n, err := conn.Read(tmp)
|
||||
chunk, err := buf.ReadBytes('}')
|
||||
if err != nil {
|
||||
if err != io.EOF {
|
||||
log.Println("read error:", err)
|
||||
log.Errorln("read error:", err)
|
||||
return request, err
|
||||
}
|
||||
eof = true
|
||||
}
|
||||
|
||||
//log.Debugln("got", len(chunk), "bytes.")
|
||||
//spew.Dump(chunk)
|
||||
|
||||
if len(chunk) > 0 {
|
||||
request = append(request, chunk...)
|
||||
|
||||
if len(request) > maxRequestSize {
|
||||
return request, errRequestTooLarge
|
||||
}
|
||||
|
||||
// yes, this is how the peer protocol knows when the request finishes
|
||||
if isValidJSON(request) {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if eof {
|
||||
break
|
||||
}
|
||||
log.Println("got", n, "bytes.")
|
||||
buf = append(buf, tmp[:n]...)
|
||||
}
|
||||
log.Println("total size:", len(buf))
|
||||
if len(buf) > 0 {
|
||||
log.Println(string(buf))
|
||||
|
||||
//log.Debugln("total size:", len(request))
|
||||
//if len(request) > 0 {
|
||||
// spew.Dump(request)
|
||||
//}
|
||||
|
||||
if len(request) == 0 && eof {
|
||||
return []byte{}, io.EOF
|
||||
}
|
||||
|
||||
return request, nil
|
||||
}
|
||||
|
||||
func isValidJSON(b []byte) bool {
|
||||
var r json.RawMessage
|
||||
return json.Unmarshal(b, &r) == nil
|
||||
}
|
||||
|
||||
func getBlobHash(blob []byte) string {
|
||||
|
|
|
@ -1,5 +1,11 @@
|
|||
package peer
|
||||
|
||||
import "github.com/lbryio/errors.go"
|
||||
|
||||
const maxRequestSize = 64 * (2 ^ 10) // 64kb
|
||||
|
||||
var errRequestTooLarge = errors.Base("request is too large")
|
||||
|
||||
type availabilityRequest struct {
|
||||
LbrycrdAddress bool `json:"lbrycrd_address"`
|
||||
RequestedBlobs []string `json:"requested_blobs"`
|
||||
|
|
29
store/dbbacked.go
Normal file
29
store/dbbacked.go
Normal file
|
@ -0,0 +1,29 @@
|
|||
package store
|
||||
|
||||
import "github.com/lbryio/reflector.go/db"
|
||||
|
||||
type DBBackedS3Store struct {
|
||||
s3 *S3BlobStore
|
||||
db db.DB
|
||||
}
|
||||
|
||||
func NewDBBackedS3Store(s3 *S3BlobStore, db db.DB) *DBBackedS3Store {
|
||||
return &DBBackedS3Store{s3: s3, db: db}
|
||||
}
|
||||
|
||||
func (d *DBBackedS3Store) Has(hash string) (bool, error) {
|
||||
return d.db.HasBlob(hash)
|
||||
}
|
||||
|
||||
func (d *DBBackedS3Store) Get(hash string) ([]byte, error) {
|
||||
return d.s3.Get(hash)
|
||||
}
|
||||
|
||||
func (d *DBBackedS3Store) Put(hash string, blob []byte) error {
|
||||
err := d.s3.Put(hash, blob)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return d.db.AddBlob(hash, len(blob))
|
||||
}
|
83
store/file.go
Normal file
83
store/file.go
Normal file
|
@ -0,0 +1,83 @@
|
|||
package store
|
||||
|
||||
import (
|
||||
"io/ioutil"
|
||||
"os"
|
||||
"path"
|
||||
|
||||
"github.com/lbryio/errors.go"
|
||||
)
|
||||
|
||||
type FileBlobStore struct {
|
||||
dir string
|
||||
|
||||
initialized bool
|
||||
}
|
||||
|
||||
func NewFileBlobStore(dir string) *FileBlobStore {
|
||||
return &FileBlobStore{dir: dir}
|
||||
}
|
||||
|
||||
func (f *FileBlobStore) path(hash string) string {
|
||||
return path.Join(f.dir, hash)
|
||||
}
|
||||
|
||||
func (f *FileBlobStore) initOnce() error {
|
||||
if f.initialized {
|
||||
return nil
|
||||
}
|
||||
defer func() { f.initialized = true }()
|
||||
|
||||
if stat, err := os.Stat(f.dir); err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
err2 := os.Mkdir(f.dir, 0755)
|
||||
if err2 != nil {
|
||||
return err2
|
||||
}
|
||||
} else {
|
||||
return err
|
||||
}
|
||||
} else if !stat.IsDir() {
|
||||
return errors.Err("blob dir exists but is not a dir")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (f *FileBlobStore) Has(hash string) (bool, error) {
|
||||
err := f.initOnce()
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
_, err = os.Stat(f.path(hash))
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
return false, nil
|
||||
}
|
||||
return false, err
|
||||
}
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func (f *FileBlobStore) Get(hash string) ([]byte, error) {
|
||||
err := f.initOnce()
|
||||
if err != nil {
|
||||
return []byte{}, err
|
||||
}
|
||||
|
||||
file, err := os.Open(f.path(hash))
|
||||
if err != nil {
|
||||
return []byte{}, err
|
||||
}
|
||||
|
||||
return ioutil.ReadAll(file)
|
||||
}
|
||||
|
||||
func (f *FileBlobStore) Put(hash string, blob []byte) error {
|
||||
err := f.initOnce()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return ioutil.WriteFile(f.path(hash), blob, 0644)
|
||||
}
|
107
store/s3.go
Normal file
107
store/s3.go
Normal file
|
@ -0,0 +1,107 @@
|
|||
package store
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/aws/aws-sdk-go/aws"
|
||||
"github.com/aws/aws-sdk-go/aws/credentials"
|
||||
"github.com/aws/aws-sdk-go/aws/session"
|
||||
"github.com/aws/aws-sdk-go/service/s3"
|
||||
"github.com/aws/aws-sdk-go/service/s3/s3manager"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
type S3BlobStore struct {
|
||||
awsID string
|
||||
awsSecret string
|
||||
region string
|
||||
bucket string
|
||||
|
||||
session *session.Session
|
||||
}
|
||||
|
||||
func NewS3BlobStore(awsID, awsSecret, region, bucket string) *S3BlobStore {
|
||||
return &S3BlobStore{
|
||||
awsID: awsID,
|
||||
awsSecret: awsSecret,
|
||||
region: region,
|
||||
bucket: bucket,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *S3BlobStore) initOnce() error {
|
||||
if s.session != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
sess, err := session.NewSession(&aws.Config{
|
||||
Credentials: credentials.NewStaticCredentials(s.awsID, s.awsSecret, ""),
|
||||
Region: aws.String(s.region),
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
s.session = sess
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *S3BlobStore) Has(hash string) (bool, error) {
|
||||
err := s.initOnce()
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
_, err = s3.New(s.session).HeadObject(&s3.HeadObjectInput{
|
||||
Bucket: aws.String(s.bucket),
|
||||
Key: aws.String(hash),
|
||||
})
|
||||
if err != nil {
|
||||
if reqFail, ok := err.(s3.RequestFailure); ok && reqFail.StatusCode() == http.StatusNotFound {
|
||||
return false, nil
|
||||
}
|
||||
return false, err
|
||||
}
|
||||
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func (s *S3BlobStore) Get(hash string) ([]byte, error) {
|
||||
err := s.initOnce()
|
||||
if err != nil {
|
||||
return []byte{}, err
|
||||
}
|
||||
|
||||
buf := &aws.WriteAtBuffer{}
|
||||
_, err = s3manager.NewDownloader(s.session).Download(buf, &s3.GetObjectInput{
|
||||
Bucket: aws.String(s.bucket),
|
||||
Key: aws.String(hash),
|
||||
})
|
||||
if err != nil {
|
||||
return buf.Bytes(), err
|
||||
}
|
||||
|
||||
return buf.Bytes(), nil
|
||||
}
|
||||
|
||||
func (s *S3BlobStore) Put(hash string, blob []byte) error {
|
||||
err := s.initOnce()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
log.Debugf("Uploading %s to S3", hash[:8])
|
||||
defer func(t time.Time) {
|
||||
log.Debugf("Upload took %s", time.Since(t).String())
|
||||
}(time.Now())
|
||||
|
||||
_, err = s3manager.NewUploader(s.session).Upload(&s3manager.UploadInput{
|
||||
Bucket: aws.String(s.bucket),
|
||||
Key: aws.String(hash),
|
||||
Body: bytes.NewBuffer(blob),
|
||||
})
|
||||
|
||||
return err
|
||||
}
|
177
store/store.go
177
store/store.go
|
@ -1,184 +1,7 @@
|
|||
package store
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"os"
|
||||
"path"
|
||||
|
||||
"github.com/lbryio/errors.go"
|
||||
|
||||
"github.com/aws/aws-sdk-go/aws"
|
||||
"github.com/aws/aws-sdk-go/aws/credentials"
|
||||
"github.com/aws/aws-sdk-go/aws/session"
|
||||
"github.com/aws/aws-sdk-go/service/s3"
|
||||
"github.com/aws/aws-sdk-go/service/s3/s3manager"
|
||||
)
|
||||
|
||||
type BlobStore interface {
|
||||
Has(string) (bool, error)
|
||||
Get(string) ([]byte, error)
|
||||
Put(string, []byte) error
|
||||
}
|
||||
|
||||
type FileBlobStore struct {
|
||||
dir string
|
||||
|
||||
initialized bool
|
||||
}
|
||||
|
||||
func NewFileBlobStore(dir string) *FileBlobStore {
|
||||
return &FileBlobStore{dir: dir}
|
||||
}
|
||||
|
||||
func (f *FileBlobStore) path(hash string) string {
|
||||
return path.Join(f.dir, hash)
|
||||
}
|
||||
|
||||
func (f *FileBlobStore) initOnce() error {
|
||||
if f.initialized {
|
||||
return nil
|
||||
}
|
||||
defer func() { f.initialized = true }()
|
||||
|
||||
if stat, err := os.Stat(f.dir); err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
err2 := os.Mkdir(f.dir, 0755)
|
||||
if err2 != nil {
|
||||
return err2
|
||||
}
|
||||
} else {
|
||||
return err
|
||||
}
|
||||
} else if !stat.IsDir() {
|
||||
return errors.Err("blob dir exists but is not a dir")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (f *FileBlobStore) Has(hash string) (bool, error) {
|
||||
err := f.initOnce()
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
_, err = os.Stat(f.path(hash))
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
return false, nil
|
||||
}
|
||||
return false, err
|
||||
}
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func (f *FileBlobStore) Get(hash string) ([]byte, error) {
|
||||
err := f.initOnce()
|
||||
if err != nil {
|
||||
return []byte{}, err
|
||||
}
|
||||
|
||||
file, err := os.Open(f.path(hash))
|
||||
if err != nil {
|
||||
return []byte{}, err
|
||||
}
|
||||
|
||||
return ioutil.ReadAll(file)
|
||||
}
|
||||
|
||||
func (f *FileBlobStore) Put(hash string, blob []byte) error {
|
||||
err := f.initOnce()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return ioutil.WriteFile(f.path(hash), blob, 0644)
|
||||
}
|
||||
|
||||
type S3BlobStore struct {
|
||||
awsID string
|
||||
awsSecret string
|
||||
region string
|
||||
bucket string
|
||||
|
||||
session *session.Session
|
||||
}
|
||||
|
||||
func NewS3BlobStore(awsID, awsSecret, region, bucket string) *S3BlobStore {
|
||||
return &S3BlobStore{
|
||||
awsID: awsID,
|
||||
awsSecret: awsSecret,
|
||||
region: region,
|
||||
bucket: bucket,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *S3BlobStore) initOnce() error {
|
||||
if s.session != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
sess, err := session.NewSession(&aws.Config{
|
||||
Credentials: credentials.NewStaticCredentials(s.awsID, s.awsSecret, ""),
|
||||
Region: aws.String(s.region),
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
s.session = sess
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *S3BlobStore) Has(hash string) (bool, error) {
|
||||
err := s.initOnce()
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
_, err = s3.New(s.session).HeadObject(&s3.HeadObjectInput{
|
||||
Bucket: aws.String(s.bucket),
|
||||
Key: aws.String(hash),
|
||||
})
|
||||
if err != nil {
|
||||
if reqFail, ok := err.(s3.RequestFailure); ok && reqFail.StatusCode() == http.StatusNotFound {
|
||||
return false, nil
|
||||
}
|
||||
return false, err
|
||||
}
|
||||
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func (s *S3BlobStore) Get(hash string) ([]byte, error) {
|
||||
err := s.initOnce()
|
||||
if err != nil {
|
||||
return []byte{}, err
|
||||
}
|
||||
|
||||
buf := &aws.WriteAtBuffer{}
|
||||
_, err = s3manager.NewDownloader(s.session).Download(buf, &s3.GetObjectInput{
|
||||
Bucket: aws.String(s.bucket),
|
||||
Key: aws.String(hash),
|
||||
})
|
||||
if err != nil {
|
||||
return buf.Bytes(), err
|
||||
}
|
||||
|
||||
return buf.Bytes(), nil
|
||||
}
|
||||
|
||||
func (s *S3BlobStore) Put(hash string, blob []byte) error {
|
||||
err := s.initOnce()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, err = s3manager.NewUploader(s.session).Upload(&s3manager.UploadInput{
|
||||
Bucket: aws.String(s.bucket),
|
||||
Key: aws.String(hash),
|
||||
Body: bytes.NewBuffer(blob),
|
||||
})
|
||||
return err
|
||||
}
|
||||
|
|
Loading…
Add table
Reference in a new issue