Simpler interface

This commit is contained in:
Justin Li 2014-07-16 13:03:59 -04:00
parent 3ceaf72034
commit 2a4d4c5aef
5 changed files with 13 additions and 39 deletions

View file

@ -69,6 +69,7 @@ type Pool interface {
// Conn represents a connection to the data store that can be used // Conn represents a connection to the data store that can be used
// to make reads/writes. // to make reads/writes.
type Conn interface { type Conn interface {
// Torrent interactions
FindTorrent(infohash string) (*models.Torrent, error) FindTorrent(infohash string) (*models.Torrent, error)
PutTorrent(t *models.Torrent) error PutTorrent(t *models.Torrent) error
DeleteTorrent(infohash string) error DeleteTorrent(infohash string) error
@ -78,14 +79,13 @@ type Conn interface {
DeleteLeecher(infohash, peerkey string) error DeleteLeecher(infohash, peerkey string) error
PutSeeder(infohash string, p *models.Peer) error PutSeeder(infohash string, p *models.Peer) error
DeleteSeeder(infohash, peerkey string) error DeleteSeeder(infohash, peerkey string) error
}
// PrivateConn represents a connection that can service queries for private trackers. // User interactions
type PrivateConn interface {
FindUser(passkey string) (*models.User, error) FindUser(passkey string) (*models.User, error)
PutUser(u *models.User) error PutUser(u *models.User) error
DeleteUser(passkey string) error DeleteUser(passkey string) error
// Whitelist interactions
FindClient(clientID string) error FindClient(clientID string) error
PutClient(clientID string) error PutClient(clientID string) error
DeleteClient(clientID string) error DeleteClient(clientID string) error

View file

@ -30,7 +30,7 @@ func (t *Tracker) ServeAnnounce(w http.ResponseWriter, r *http.Request, p httpro
} }
if t.cfg.Whitelist { if t.cfg.Whitelist {
err = conn.(tracker.PrivateConn).FindClient(ann.ClientID()) err = conn.FindClient(ann.ClientID())
if err == tracker.ErrClientUnapproved { if err == tracker.ErrClientUnapproved {
fail(w, r, err) fail(w, r, err)
return http.StatusOK, nil return http.StatusOK, nil
@ -41,7 +41,7 @@ func (t *Tracker) ServeAnnounce(w http.ResponseWriter, r *http.Request, p httpro
var user *models.User var user *models.User
if t.cfg.Private { if t.cfg.Private {
user, err = conn.(tracker.PrivateConn).FindUser(ann.Passkey) user, err = conn.FindUser(ann.Passkey)
if err == tracker.ErrUserDNE { if err == tracker.ErrUserDNE {
fail(w, r, err) fail(w, r, err)
return http.StatusOK, nil return http.StatusOK, nil

View file

@ -9,7 +9,6 @@ import (
"github.com/chihaya/bencode" "github.com/chihaya/bencode"
"github.com/chihaya/chihaya/config" "github.com/chihaya/chihaya/config"
"github.com/chihaya/chihaya/drivers/tracker"
"github.com/chihaya/chihaya/models" "github.com/chihaya/chihaya/models"
) )
@ -76,7 +75,7 @@ func loadTestData(tkr *Tracker) error {
} }
for i, passkey := range users { for i, passkey := range users {
err = conn.(tracker.PrivateConn).PutUser(&models.User{ err = conn.PutUser(&models.User{
ID: uint64(i + 1), ID: uint64(i + 1),
Passkey: passkey, Passkey: passkey,
}) })
@ -86,7 +85,7 @@ func loadTestData(tkr *Tracker) error {
} }
} }
err = conn.(tracker.PrivateConn).PutClient("TR2820") err = conn.PutClient("TR2820")
if err != nil { if err != nil {
return err return err
} }

View file

@ -102,16 +102,11 @@ func (t *Tracker) delTorrent(w http.ResponseWriter, r *http.Request, p httproute
} }
func (t *Tracker) getUser(w http.ResponseWriter, r *http.Request, p httprouter.Params) (int, error) { func (t *Tracker) getUser(w http.ResponseWriter, r *http.Request, p httprouter.Params) (int, error) {
base, err := t.pool.Get() conn, err := t.pool.Get()
if err != nil { if err != nil {
return http.StatusInternalServerError, err return http.StatusInternalServerError, err
} }
conn, private := base.(tracker.PrivateConn)
if !private {
return http.StatusNotFound, nil
}
user, err := conn.FindUser(p.ByName("passkey")) user, err := conn.FindUser(p.ByName("passkey"))
if err == tracker.ErrUserDNE { if err == tracker.ErrUserDNE {
return http.StatusNotFound, err return http.StatusNotFound, err
@ -141,16 +136,11 @@ func (t *Tracker) putUser(w http.ResponseWriter, r *http.Request, p httprouter.P
return http.StatusBadRequest, err return http.StatusBadRequest, err
} }
base, err := t.pool.Get() conn, err := t.pool.Get()
if err != nil { if err != nil {
return http.StatusInternalServerError, err return http.StatusInternalServerError, err
} }
conn, private := base.(tracker.PrivateConn)
if !private {
return http.StatusNotFound, nil
}
err = conn.PutUser(&user) err = conn.PutUser(&user)
if err != nil { if err != nil {
return http.StatusInternalServerError, err return http.StatusInternalServerError, err
@ -160,16 +150,11 @@ func (t *Tracker) putUser(w http.ResponseWriter, r *http.Request, p httprouter.P
} }
func (t *Tracker) delUser(w http.ResponseWriter, r *http.Request, p httprouter.Params) (int, error) { func (t *Tracker) delUser(w http.ResponseWriter, r *http.Request, p httprouter.Params) (int, error) {
base, err := t.pool.Get() conn, err := t.pool.Get()
if err != nil { if err != nil {
return http.StatusInternalServerError, err return http.StatusInternalServerError, err
} }
conn, private := base.(tracker.PrivateConn)
if !private {
return http.StatusNotFound, nil
}
err = conn.DeleteUser(p.ByName("passkey")) err = conn.DeleteUser(p.ByName("passkey"))
if err == tracker.ErrUserDNE { if err == tracker.ErrUserDNE {
return http.StatusNotFound, err return http.StatusNotFound, err
@ -181,16 +166,11 @@ func (t *Tracker) delUser(w http.ResponseWriter, r *http.Request, p httprouter.P
} }
func (t *Tracker) putClient(w http.ResponseWriter, r *http.Request, p httprouter.Params) (int, error) { func (t *Tracker) putClient(w http.ResponseWriter, r *http.Request, p httprouter.Params) (int, error) {
base, err := t.pool.Get() conn, err := t.pool.Get()
if err != nil { if err != nil {
return http.StatusInternalServerError, err return http.StatusInternalServerError, err
} }
conn, private := base.(tracker.PrivateConn)
if !private {
return http.StatusNotFound, nil
}
err = conn.PutClient(p.ByName("clientID")) err = conn.PutClient(p.ByName("clientID"))
if err != nil { if err != nil {
return http.StatusInternalServerError, err return http.StatusInternalServerError, err
@ -200,16 +180,11 @@ func (t *Tracker) putClient(w http.ResponseWriter, r *http.Request, p httprouter
} }
func (t *Tracker) delClient(w http.ResponseWriter, r *http.Request, p httprouter.Params) (int, error) { func (t *Tracker) delClient(w http.ResponseWriter, r *http.Request, p httprouter.Params) (int, error) {
base, err := t.pool.Get() conn, err := t.pool.Get()
if err != nil { if err != nil {
return http.StatusInternalServerError, err return http.StatusInternalServerError, err
} }
conn, private := base.(tracker.PrivateConn)
if !private {
return http.StatusNotFound, nil
}
err = conn.DeleteClient(p.ByName("clientID")) err = conn.DeleteClient(p.ByName("clientID"))
if err != nil { if err != nil {
return http.StatusInternalServerError, err return http.StatusInternalServerError, err

View file

@ -29,7 +29,7 @@ func (t *Tracker) ServeScrape(w http.ResponseWriter, r *http.Request, p httprout
} }
if t.cfg.Private { if t.cfg.Private {
_, err = conn.(tracker.PrivateConn).FindUser(scrape.Passkey) _, err = conn.FindUser(scrape.Passkey)
if err == tracker.ErrUserDNE { if err == tracker.ErrUserDNE {
fail(w, r, err) fail(w, r, err)
return http.StatusOK, nil return http.StatusOK, nil