diff --git a/rpcserver.go b/rpcserver.go index 851e1f6c..f69ecb78 100644 --- a/rpcserver.go +++ b/rpcserver.go @@ -272,419 +272,6 @@ func newGbtWorkState(timeSource blockchain.MedianTimeSource) *gbtWorkState { } } -// rpcServer holds the items the rpc server may need to access (config, -// shutdown, main server, etc.) -type rpcServer struct { - started int32 - shutdown int32 - server *server - authsha [fastsha256.Size]byte - ntfnMgr *wsNotificationManager - numClients int32 - statusLines map[int]string - statusLock sync.RWMutex - wg sync.WaitGroup - listeners []net.Listener - workState *workState - gbtWorkState *gbtWorkState - quit chan int -} - -// Start is used by server.go to start the rpc listener. -func (s *rpcServer) Start() { - if atomic.AddInt32(&s.started, 1) != 1 { - return - } - - rpcsLog.Trace("Starting RPC server") - rpcServeMux := http.NewServeMux() - httpServer := &http.Server{ - Handler: rpcServeMux, - - // Timeout connections which don't complete the initial - // handshake within the allowed timeframe. - ReadTimeout: time.Second * rpcAuthTimeoutSeconds, - } - rpcServeMux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Connection", "close") - w.Header().Set("Content-Type", "application/json") - r.Close = true - - // Limit the number of connections to max allowed. - if s.limitConnections(w, r.RemoteAddr) { - return - } - - // Keep track of the number of connected clients. - s.incrementClients() - defer s.decrementClients() - if _, err := s.checkAuth(r, true); err != nil { - jsonAuthFail(w, r, s) - return - } - jsonRPCRead(w, r, s) - }) - - // Websocket endpoint. - rpcServeMux.HandleFunc("/ws", func(w http.ResponseWriter, r *http.Request) { - authenticated, err := s.checkAuth(r, false) - if err != nil { - http.Error(w, "401 Unauthorized.", http.StatusUnauthorized) - return - } - - // Attempt to upgrade the connection to a websocket connection - // using the default size for read/write buffers. - ws, err := websocket.Upgrade(w, r, nil, 0, 0) - if err != nil { - if _, ok := err.(websocket.HandshakeError); !ok { - rpcsLog.Errorf("Unexpected websocket error: %v", - err) - } - return - } - s.WebsocketHandler(ws, r.RemoteAddr, authenticated) - }) - - for _, listener := range s.listeners { - s.wg.Add(1) - go func(listener net.Listener) { - rpcsLog.Infof("RPC server listening on %s", listener.Addr()) - httpServer.Serve(listener) - rpcsLog.Tracef("RPC listener done for %s", listener.Addr()) - s.wg.Done() - }(listener) - } - - s.ntfnMgr.Start() -} - -// httpStatusLine returns a response Status-Line (RFC 2616 Section 6.1) -// for the given request and response status code. This function was lifted and -// adapted from the standard library HTTP server code since it's not exported. -func (s *rpcServer) httpStatusLine(req *http.Request, code int) string { - // Fast path: - key := code - proto11 := req.ProtoAtLeast(1, 1) - if !proto11 { - key = -key - } - s.statusLock.RLock() - line, ok := s.statusLines[key] - s.statusLock.RUnlock() - if ok { - return line - } - - // Slow path: - proto := "HTTP/1.0" - if proto11 { - proto = "HTTP/1.1" - } - codeStr := strconv.Itoa(code) - text := http.StatusText(code) - if text != "" { - line = proto + " " + codeStr + " " + text + "\r\n" - s.statusLock.Lock() - s.statusLines[key] = line - s.statusLock.Unlock() - } else { - text = "status code " + codeStr - line = proto + " " + codeStr + " " + text + "\r\n" - } - - return line -} - -// writeHTTPResponseHeaders writes the necessary response headers prior to -// writing an HTTP body given a request to use for protocol negotiation, headers -// to write, a status code, and a writer. -func (s *rpcServer) writeHTTPResponseHeaders(req *http.Request, headers http.Header, code int, w io.Writer) error { - _, err := io.WriteString(w, s.httpStatusLine(req, code)) - if err != nil { - return err - } - - err = headers.Write(w) - if err != nil { - return err - } - - _, err = io.WriteString(w, "\r\n") - if err != nil { - return err - } - - return nil -} - -// limitConnections responds with a 503 service unavailable and returns true if -// adding another client would exceed the maximum allow RPC clients. -// -// This function is safe for concurrent access. -func (s *rpcServer) limitConnections(w http.ResponseWriter, remoteAddr string) bool { - if int(atomic.LoadInt32(&s.numClients)+1) > cfg.RPCMaxClients { - rpcsLog.Infof("Max RPC clients exceeded [%d] - "+ - "disconnecting client %s", cfg.RPCMaxClients, - remoteAddr) - http.Error(w, "503 Too busy. Try again later.", - http.StatusServiceUnavailable) - return true - } - return false -} - -// incrementClients adds one to the number of connected RPC clients. Note -// this only applies to standard clients. Websocket clients have their own -// limits and are tracked separately. -// -// This function is safe for concurrent access. -func (s *rpcServer) incrementClients() { - atomic.AddInt32(&s.numClients, 1) -} - -// decrementClients subtracts one from the number of connected RPC clients. -// Note this only applies to standard clients. Websocket clients have their own -// limits and are tracked separately. -// -// This function is safe for concurrent access. -func (s *rpcServer) decrementClients() { - atomic.AddInt32(&s.numClients, -1) -} - -// checkAuth checks the HTTP Basic authentication supplied by a wallet -// or RPC client in the HTTP request r. If the supplied authentication -// does not match the username and password expected, a non-nil error is -// returned. -// -// This check is time-constant. -func (s *rpcServer) checkAuth(r *http.Request, require bool) (bool, error) { - authhdr := r.Header["Authorization"] - if len(authhdr) <= 0 { - if require { - rpcsLog.Warnf("RPC authentication failure from %s", - r.RemoteAddr) - return false, errors.New("auth failure") - } - - return false, nil - } - - authsha := fastsha256.Sum256([]byte(authhdr[0])) - cmp := subtle.ConstantTimeCompare(authsha[:], s.authsha[:]) - if cmp != 1 { - rpcsLog.Warnf("RPC authentication failure from %s", r.RemoteAddr) - return false, errors.New("auth failure") - } - return true, nil -} - -// Stop is used by server.go to stop the rpc listener. -func (s *rpcServer) Stop() error { - if atomic.AddInt32(&s.shutdown, 1) != 1 { - rpcsLog.Infof("RPC server is already in the process of shutting down") - return nil - } - rpcsLog.Warnf("RPC server shutting down") - for _, listener := range s.listeners { - err := listener.Close() - if err != nil { - rpcsLog.Errorf("Problem shutting down rpc: %v", err) - return err - } - } - s.ntfnMgr.Shutdown() - s.ntfnMgr.WaitForShutdown() - close(s.quit) - s.wg.Wait() - rpcsLog.Infof("RPC server shutdown complete") - return nil -} - -// genCertPair generates a key/cert pair to the paths provided. -func genCertPair(certFile, keyFile string) error { - rpcsLog.Infof("Generating TLS certificates...") - - org := "btcd autogenerated cert" - validUntil := time.Now().Add(10 * 365 * 24 * time.Hour) - cert, key, err := btcutil.NewTLSCertPair(org, validUntil, nil) - if err != nil { - return err - } - - // Write cert and key files. - if err = ioutil.WriteFile(certFile, cert, 0666); err != nil { - return err - } - if err = ioutil.WriteFile(keyFile, key, 0600); err != nil { - os.Remove(certFile) - return err - } - - rpcsLog.Infof("Done generating TLS certificates") - return nil -} - -// newRPCServer returns a new instance of the rpcServer struct. -func newRPCServer(listenAddrs []string, s *server) (*rpcServer, error) { - login := cfg.RPCUser + ":" + cfg.RPCPass - auth := "Basic " + base64.StdEncoding.EncodeToString([]byte(login)) - rpc := rpcServer{ - authsha: fastsha256.Sum256([]byte(auth)), - server: s, - statusLines: make(map[int]string), - workState: newWorkState(), - gbtWorkState: newGbtWorkState(s.timeSource), - quit: make(chan int), - } - rpc.ntfnMgr = newWsNotificationManager(&rpc) - - // Setup TLS if not disabled. - listenFunc := net.Listen - if !cfg.DisableTLS { - // Generate the TLS cert and key file if both don't already - // exist. - if !fileExists(cfg.RPCKey) && !fileExists(cfg.RPCCert) { - err := genCertPair(cfg.RPCCert, cfg.RPCKey) - if err != nil { - return nil, err - } - } - keypair, err := tls.LoadX509KeyPair(cfg.RPCCert, cfg.RPCKey) - if err != nil { - return nil, err - } - - tlsConfig := tls.Config{ - Certificates: []tls.Certificate{keypair}, - MinVersion: tls.VersionTLS12, - } - - // Change the standard net.Listen function to the tls one. - listenFunc = func(net string, laddr string) (net.Listener, error) { - return tls.Listen(net, laddr, &tlsConfig) - } - } - - // TODO(oga) this code is similar to that in server, should be - // factored into something shared. - ipv4ListenAddrs, ipv6ListenAddrs, _, err := parseListeners(listenAddrs) - if err != nil { - return nil, err - } - listeners := make([]net.Listener, 0, - len(ipv6ListenAddrs)+len(ipv4ListenAddrs)) - for _, addr := range ipv4ListenAddrs { - listener, err := listenFunc("tcp4", addr) - if err != nil { - rpcsLog.Warnf("Can't listen on %s: %v", addr, err) - continue - } - listeners = append(listeners, listener) - } - - for _, addr := range ipv6ListenAddrs { - listener, err := listenFunc("tcp6", addr) - if err != nil { - rpcsLog.Warnf("Can't listen on %s: %v", addr, err) - continue - } - listeners = append(listeners, listener) - } - if len(listeners) == 0 { - return nil, errors.New("RPCS: No valid listen address") - } - - rpc.listeners = listeners - - return &rpc, nil -} - -// jsonAuthFail sends a message back to the client if the http auth is rejected. -func jsonAuthFail(w http.ResponseWriter, r *http.Request, s *rpcServer) { - w.Header().Add("WWW-Authenticate", `Basic realm="btcd RPC"`) - http.Error(w, "401 Unauthorized.", http.StatusUnauthorized) -} - -// jsonRPCRead is the RPC wrapper around the jsonRead function to handle reading -// and responding to RPC messages. -func jsonRPCRead(w http.ResponseWriter, r *http.Request, s *rpcServer) { - if atomic.LoadInt32(&s.shutdown) != 0 { - return - } - body, err := btcjson.GetRaw(r.Body) - if err != nil { - rpcsLog.Errorf("Error getting json message: %v", err) - return - } - - // Unfortunately, the http server doesn't provide the ability to - // change the read deadline for the new connection and having one breaks - // long polling. However, not having a read deadline on the initial - // connection would mean clients can connect and idle forever. Thus, - // hijack the connecton from the HTTP server, clear the read deadline, - // and handle writing the response manually. - hj, ok := w.(http.Hijacker) - if !ok { - errMsg := "webserver doesn't support hijacking" - rpcsLog.Warnf(errMsg) - errCode := http.StatusInternalServerError - http.Error(w, strconv.FormatInt(int64(errCode), 10)+" "+errMsg, - errCode) - return - } - conn, buf, err := hj.Hijack() - if err != nil { - rpcsLog.Warnf("Failed to hijack HTTP connection: %v", err) - errCode := http.StatusInternalServerError - http.Error(w, strconv.FormatInt(int64(errCode), 10)+" "+ - err.Error(), errCode) - return - } - defer conn.Close() - defer buf.Flush() - conn.SetReadDeadline(timeZeroVal) - - var reply btcjson.Reply - cmd, jsonErr := parseCmd(body) - if cmd != nil { - // Unmarshaling at least a valid JSON-RPC message succeeded. - // Use the provided id for errors. - id := cmd.Id() - reply.Id = &id - } - if jsonErr != nil { - reply.Error = jsonErr - } else { - // Setup a close notifier. Since the connection is hijacked, - // the CloseNotifer on the ResponseWriter is not available. - closeChan := make(chan struct{}, 1) - go func() { - _, err := conn.Read(make([]byte, 1)) - if err != nil { - close(closeChan) - } - }() - - reply = standardCmdReply(cmd, s, closeChan) - } - - rpcsLog.Tracef("reply: %v", reply) - - err = s.writeHTTPResponseHeaders(r, w.Header(), http.StatusOK, buf) - if err != nil { - rpcsLog.Error(err) - return - } - - msg, err := btcjson.MarshallAndSend(reply, buf) - if err != nil { - rpcsLog.Error(err) - return - } - rpcsLog.Tracef(msg) -} - // handleUnimplemented is a temporary handler for commands that we should // support but do not. func handleUnimplemented(s *rpcServer, cmd btcjson.Cmd, closeChan <-chan struct{}) (interface{}, error) { @@ -1143,6 +730,26 @@ func handleGetBestBlockHash(s *rpcServer, cmd btcjson.Cmd, closeChan <-chan stru return sha.String(), nil } +// getDifficultyRatio returns the proof-of-work difficulty as a multiple of the +// minimum difficulty using the passed bits field from the header of a block. +func getDifficultyRatio(bits uint32) float64 { + // The minimum difficulty is the max possible proof-of-work limit bits + // converted back to a number. Note this is not the same as the the + // proof of work limit directly because the block difficulty is encoded + // in a block with the compact form which loses precision. + max := blockchain.CompactToBig(activeNetParams.PowLimitBits) + target := blockchain.CompactToBig(bits) + + difficulty := new(big.Rat).SetFrac(max, target) + outString := difficulty.FloatString(2) + diff, err := strconv.ParseFloat(outString, 64) + if err != nil { + rpcsLog.Errorf("Cannot get difficulty: %v", err) + return 0 + } + return diff +} + // handleGetBlock implements the getblock command. func handleGetBlock(s *rpcServer, cmd btcjson.Cmd, closeChan <-chan struct{}) (interface{}, error) { c := cmd.(*btcjson.GetBlockCmd) @@ -3198,6 +2805,23 @@ func handleSubmitBlock(s *rpcServer, cmd btcjson.Cmd, closeChan <-chan struct{}) return nil, nil } +// handleValidateAddress implements the validateaddress command. +func handleValidateAddress(s *rpcServer, cmd interface{}, closeChan <-chan struct{}) (interface{}, error) { + c := cmd.(*btcjson.ValidateAddressCmd) + + result := btcjson.ValidateAddressChainResult{} + addr, err := btcutil.DecodeAddress(c.Address, activeNetParams.Params) + if err != nil { + // Return the default value (false) for IsValid. + return result, nil + } + + result.Address = addr.EncodeAddress() + result.IsValid = true + + return result, nil +} + func verifyChain(db database.Db, level, depth int32, timeSource blockchain.MedianTimeSource) error { _, curHeight64, err := db.NewestSha() if err != nil { @@ -3246,23 +2870,6 @@ func verifyChain(db database.Db, level, depth int32, timeSource blockchain.Media return nil } -// handleValidateAddress implements the validateaddress command. -func handleValidateAddress(s *rpcServer, cmd btcjson.Cmd, closeChan <-chan struct{}) (interface{}, error) { - c := cmd.(*btcjson.ValidateAddressCmd) - - result := btcjson.ValidateAddressResult{} - addr, err := btcutil.DecodeAddress(c.Address, activeNetParams.Params) - if err != nil { - // Return the default value (false) for IsValid. - return result, nil - } - - result.Address = addr.EncodeAddress() - result.IsValid = true - - return result, nil -} - // handleVerifyChain implements the verifychain command. func handleVerifyChain(s *rpcServer, cmd btcjson.Cmd, closeChan <-chan struct{}) (interface{}, error) { c := cmd.(*btcjson.VerifyChainCmd) @@ -3333,28 +2940,170 @@ func handleVerifyMessage(s *rpcServer, cmd btcjson.Cmd, closeChan <-chan struct{ return address.EncodeAddress() == c.Address, nil } -// parseCmd parses a marshaled known command, returning any errors as a -// btcjson.Error that can be used in replies. The returned cmd may still -// be non-nil if b is at least a valid marshaled JSON-RPC message. -func parseCmd(b []byte) (btcjson.Cmd, *btcjson.Error) { - cmd, err := btcjson.ParseMarshaledCmd(b) - if err != nil { - jsonErr, ok := err.(btcjson.Error) - if !ok { - jsonErr = btcjson.Error{ - Code: btcjson.ErrParse.Code, - Message: err.Error(), - } - } - return cmd, &jsonErr +// rpcServer holds the items the rpc server may need to access (config, +// shutdown, main server, etc.) +type rpcServer struct { + started int32 + shutdown int32 + server *server + authsha [fastsha256.Size]byte + ntfnMgr *wsNotificationManager + numClients int32 + statusLines map[int]string + statusLock sync.RWMutex + wg sync.WaitGroup + listeners []net.Listener + workState *workState + gbtWorkState *gbtWorkState + quit chan int +} + +// httpStatusLine returns a response Status-Line (RFC 2616 Section 6.1) +// for the given request and response status code. This function was lifted and +// adapted from the standard library HTTP server code since it's not exported. +func (s *rpcServer) httpStatusLine(req *http.Request, code int) string { + // Fast path: + key := code + proto11 := req.ProtoAtLeast(1, 1) + if !proto11 { + key = -key } - return cmd, nil + s.statusLock.RLock() + line, ok := s.statusLines[key] + s.statusLock.RUnlock() + if ok { + return line + } + + // Slow path: + proto := "HTTP/1.0" + if proto11 { + proto = "HTTP/1.1" + } + codeStr := strconv.Itoa(code) + text := http.StatusText(code) + if text != "" { + line = proto + " " + codeStr + " " + text + "\r\n" + s.statusLock.Lock() + s.statusLines[key] = line + s.statusLock.Unlock() + } else { + text = "status code " + codeStr + line = proto + " " + codeStr + " " + text + "\r\n" + } + + return line +} + +// writeHTTPResponseHeaders writes the necessary response headers prior to +// writing an HTTP body given a request to use for protocol negotiation, headers +// to write, a status code, and a writer. +func (s *rpcServer) writeHTTPResponseHeaders(req *http.Request, headers http.Header, code int, w io.Writer) error { + _, err := io.WriteString(w, s.httpStatusLine(req, code)) + if err != nil { + return err + } + + err = headers.Write(w) + if err != nil { + return err + } + + _, err = io.WriteString(w, "\r\n") + if err != nil { + return err + } + + return nil +} + +// Stop is used by server.go to stop the rpc listener. +func (s *rpcServer) Stop() error { + if atomic.AddInt32(&s.shutdown, 1) != 1 { + rpcsLog.Infof("RPC server is already in the process of shutting down") + return nil + } + rpcsLog.Warnf("RPC server shutting down") + for _, listener := range s.listeners { + err := listener.Close() + if err != nil { + rpcsLog.Errorf("Problem shutting down rpc: %v", err) + return err + } + } + s.ntfnMgr.Shutdown() + s.ntfnMgr.WaitForShutdown() + close(s.quit) + s.wg.Wait() + rpcsLog.Infof("RPC server shutdown complete") + return nil +} + +// limitConnections responds with a 503 service unavailable and returns true if +// adding another client would exceed the maximum allow RPC clients. +// +// This function is safe for concurrent access. +func (s *rpcServer) limitConnections(w http.ResponseWriter, remoteAddr string) bool { + if int(atomic.LoadInt32(&s.numClients)+1) > cfg.RPCMaxClients { + rpcsLog.Infof("Max RPC clients exceeded [%d] - "+ + "disconnecting client %s", cfg.RPCMaxClients, + remoteAddr) + http.Error(w, "503 Too busy. Try again later.", + http.StatusServiceUnavailable) + return true + } + return false +} + +// incrementClients adds one to the number of connected RPC clients. Note +// this only applies to standard clients. Websocket clients have their own +// limits and are tracked separately. +// +// This function is safe for concurrent access. +func (s *rpcServer) incrementClients() { + atomic.AddInt32(&s.numClients, 1) +} + +// decrementClients subtracts one from the number of connected RPC clients. +// Note this only applies to standard clients. Websocket clients have their own +// limits and are tracked separately. +// +// This function is safe for concurrent access. +func (s *rpcServer) decrementClients() { + atomic.AddInt32(&s.numClients, -1) +} + +// checkAuth checks the HTTP Basic authentication supplied by a wallet +// or RPC client in the HTTP request r. If the supplied authentication +// does not match the username and password expected, a non-nil error is +// returned. +// +// This check is time-constant. +func (s *rpcServer) checkAuth(r *http.Request, require bool) (bool, error) { + authhdr := r.Header["Authorization"] + if len(authhdr) <= 0 { + if require { + rpcsLog.Warnf("RPC authentication failure from %s", + r.RemoteAddr) + return false, errors.New("auth failure") + } + + return false, nil + } + + authsha := fastsha256.Sum256([]byte(authhdr[0])) + cmp := subtle.ConstantTimeCompare(authsha[:], s.authsha[:]) + if cmp != 1 { + rpcsLog.Warnf("RPC authentication failure from %s", r.RemoteAddr) + return false, errors.New("auth failure") + } + return true, nil } // standardCmdReply checks that a parsed command is a standard // Bitcoin JSON-RPC command and runs the proper handler to reply to the // command. -func standardCmdReply(cmd btcjson.Cmd, s *rpcServer, closeChan <-chan struct{}) (reply btcjson.Reply) { +func (s *rpcServer) standardCmdReply(cmd btcjson.Cmd, closeChan <-chan struct{}) (reply btcjson.Reply) { id := cmd.Id() reply.Id = &id @@ -3395,24 +3144,275 @@ handled: return reply } -// getDifficultyRatio returns the proof-of-work difficulty as a multiple of the -// minimum difficulty using the passed bits field from the header of a block. -func getDifficultyRatio(bits uint32) float64 { - // The minimum difficulty is the max possible proof-of-work limit bits - // converted back to a number. Note this is not the same as the the - // proof of work limit directly because the block difficulty is encoded - // in a block with the compact form which loses precision. - max := blockchain.CompactToBig(activeNetParams.PowLimitBits) - target := blockchain.CompactToBig(bits) - - difficulty := new(big.Rat).SetFrac(max, target) - outString := difficulty.FloatString(2) - diff, err := strconv.ParseFloat(outString, 64) +// parseCmd parses a marshaled known command, returning any errors as a +// btcjson.Error that can be used in replies. The returned cmd may still +// be non-nil if b is at least a valid marshaled JSON-RPC message. +func parseCmd(b []byte) (btcjson.Cmd, *btcjson.Error) { + cmd, err := btcjson.ParseMarshaledCmd(b) if err != nil { - rpcsLog.Errorf("Cannot get difficulty: %v", err) - return 0 + jsonErr, ok := err.(btcjson.Error) + if !ok { + jsonErr = btcjson.Error{ + Code: btcjson.ErrParse.Code, + Message: err.Error(), + } + } + return cmd, &jsonErr } - return diff + return cmd, nil +} + +// jsonRPCRead is the RPC wrapper around the jsonRead function to handle reading +// and responding to RPC messages. +func (s *rpcServer) jsonRPCRead(w http.ResponseWriter, r *http.Request) { + if atomic.LoadInt32(&s.shutdown) != 0 { + return + } + body, err := btcjson.GetRaw(r.Body) + if err != nil { + rpcsLog.Errorf("Error getting json message: %v", err) + return + } + + // Unfortunately, the http server doesn't provide the ability to + // change the read deadline for the new connection and having one breaks + // long polling. However, not having a read deadline on the initial + // connection would mean clients can connect and idle forever. Thus, + // hijack the connecton from the HTTP server, clear the read deadline, + // and handle writing the response manually. + hj, ok := w.(http.Hijacker) + if !ok { + errMsg := "webserver doesn't support hijacking" + rpcsLog.Warnf(errMsg) + errCode := http.StatusInternalServerError + http.Error(w, strconv.FormatInt(int64(errCode), 10)+" "+errMsg, + errCode) + return + } + conn, buf, err := hj.Hijack() + if err != nil { + rpcsLog.Warnf("Failed to hijack HTTP connection: %v", err) + errCode := http.StatusInternalServerError + http.Error(w, strconv.FormatInt(int64(errCode), 10)+" "+ + err.Error(), errCode) + return + } + defer conn.Close() + defer buf.Flush() + conn.SetReadDeadline(timeZeroVal) + + var reply btcjson.Reply + cmd, jsonErr := parseCmd(body) + if cmd != nil { + // Unmarshaling at least a valid JSON-RPC message succeeded. + // Use the provided id for errors. + id := cmd.Id() + reply.Id = &id + } + if jsonErr != nil { + reply.Error = jsonErr + } else { + // Setup a close notifier. Since the connection is hijacked, + // the CloseNotifer on the ResponseWriter is not available. + closeChan := make(chan struct{}, 1) + go func() { + _, err := conn.Read(make([]byte, 1)) + if err != nil { + close(closeChan) + } + }() + + reply = s.standardCmdReply(cmd, closeChan) + } + + rpcsLog.Tracef("reply: %v", reply) + + err = s.writeHTTPResponseHeaders(r, w.Header(), http.StatusOK, buf) + if err != nil { + rpcsLog.Error(err) + return + } + + msg, err := btcjson.MarshallAndSend(reply, buf) + if err != nil { + rpcsLog.Error(err) + return + } + rpcsLog.Tracef(msg) +} + +// jsonAuthFail sends a message back to the client if the http auth is rejected. +func jsonAuthFail(w http.ResponseWriter) { + w.Header().Add("WWW-Authenticate", `Basic realm="btcd RPC"`) + http.Error(w, "401 Unauthorized.", http.StatusUnauthorized) +} + +// Start is used by server.go to start the rpc listener. +func (s *rpcServer) Start() { + if atomic.AddInt32(&s.started, 1) != 1 { + return + } + + rpcsLog.Trace("Starting RPC server") + rpcServeMux := http.NewServeMux() + httpServer := &http.Server{ + Handler: rpcServeMux, + + // Timeout connections which don't complete the initial + // handshake within the allowed timeframe. + ReadTimeout: time.Second * rpcAuthTimeoutSeconds, + } + rpcServeMux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Connection", "close") + w.Header().Set("Content-Type", "application/json") + r.Close = true + + // Limit the number of connections to max allowed. + if s.limitConnections(w, r.RemoteAddr) { + return + } + + // Keep track of the number of connected clients. + s.incrementClients() + defer s.decrementClients() + if _, err := s.checkAuth(r, true); err != nil { + jsonAuthFail(w) + return + } + s.jsonRPCRead(w, r) + }) + + // Websocket endpoint. + rpcServeMux.HandleFunc("/ws", func(w http.ResponseWriter, r *http.Request) { + authenticated, err := s.checkAuth(r, false) + if err != nil { + http.Error(w, "401 Unauthorized.", http.StatusUnauthorized) + return + } + + // Attempt to upgrade the connection to a websocket connection + // using the default size for read/write buffers. + ws, err := websocket.Upgrade(w, r, nil, 0, 0) + if err != nil { + if _, ok := err.(websocket.HandshakeError); !ok { + rpcsLog.Errorf("Unexpected websocket error: %v", + err) + } + return + } + s.WebsocketHandler(ws, r.RemoteAddr, authenticated) + }) + + for _, listener := range s.listeners { + s.wg.Add(1) + go func(listener net.Listener) { + rpcsLog.Infof("RPC server listening on %s", listener.Addr()) + httpServer.Serve(listener) + rpcsLog.Tracef("RPC listener done for %s", listener.Addr()) + s.wg.Done() + }(listener) + } + + s.ntfnMgr.Start() +} + +// genCertPair generates a key/cert pair to the paths provided. +func genCertPair(certFile, keyFile string) error { + rpcsLog.Infof("Generating TLS certificates...") + + org := "btcd autogenerated cert" + validUntil := time.Now().Add(10 * 365 * 24 * time.Hour) + cert, key, err := btcutil.NewTLSCertPair(org, validUntil, nil) + if err != nil { + return err + } + + // Write cert and key files. + if err = ioutil.WriteFile(certFile, cert, 0666); err != nil { + return err + } + if err = ioutil.WriteFile(keyFile, key, 0600); err != nil { + os.Remove(certFile) + return err + } + + rpcsLog.Infof("Done generating TLS certificates") + return nil +} + +// newRPCServer returns a new instance of the rpcServer struct. +func newRPCServer(listenAddrs []string, s *server) (*rpcServer, error) { + login := cfg.RPCUser + ":" + cfg.RPCPass + auth := "Basic " + base64.StdEncoding.EncodeToString([]byte(login)) + rpc := rpcServer{ + authsha: fastsha256.Sum256([]byte(auth)), + server: s, + statusLines: make(map[int]string), + workState: newWorkState(), + gbtWorkState: newGbtWorkState(s.timeSource), + quit: make(chan int), + } + rpc.ntfnMgr = newWsNotificationManager(&rpc) + + // Setup TLS if not disabled. + listenFunc := net.Listen + if !cfg.DisableTLS { + // Generate the TLS cert and key file if both don't already + // exist. + if !fileExists(cfg.RPCKey) && !fileExists(cfg.RPCCert) { + err := genCertPair(cfg.RPCCert, cfg.RPCKey) + if err != nil { + return nil, err + } + } + keypair, err := tls.LoadX509KeyPair(cfg.RPCCert, cfg.RPCKey) + if err != nil { + return nil, err + } + + tlsConfig := tls.Config{ + Certificates: []tls.Certificate{keypair}, + MinVersion: tls.VersionTLS12, + } + + // Change the standard net.Listen function to the tls one. + listenFunc = func(net string, laddr string) (net.Listener, error) { + return tls.Listen(net, laddr, &tlsConfig) + } + } + + // TODO(oga) this code is similar to that in server, should be + // factored into something shared. + ipv4ListenAddrs, ipv6ListenAddrs, _, err := parseListeners(listenAddrs) + if err != nil { + return nil, err + } + listeners := make([]net.Listener, 0, + len(ipv6ListenAddrs)+len(ipv4ListenAddrs)) + for _, addr := range ipv4ListenAddrs { + listener, err := listenFunc("tcp4", addr) + if err != nil { + rpcsLog.Warnf("Can't listen on %s: %v", addr, err) + continue + } + listeners = append(listeners, listener) + } + + for _, addr := range ipv6ListenAddrs { + listener, err := listenFunc("tcp6", addr) + if err != nil { + rpcsLog.Warnf("Can't listen on %s: %v", addr, err) + continue + } + listeners = append(listeners, listener) + } + if len(listeners) == 0 { + return nil, errors.New("RPCS: No valid listen address") + } + + rpc.listeners = listeners + + return &rpc, nil } func init() { diff --git a/rpcwebsocket.go b/rpcwebsocket.go index f169baf2..71bc2dda 100644 --- a/rpcwebsocket.go +++ b/rpcwebsocket.go @@ -1024,7 +1024,7 @@ func (c *wsClient) handleMessage(msg []byte) { if !ok { // No websocket-specific handler so handle like a legacy // RPC connection. - response := standardCmdReply(cmd, c.server, nil) + response := c.server.standardCmdReply(cmd, nil) reply, err := json.Marshal(response) if err != nil { rpcsLog.Errorf("Failed to marshal reply for <%s> "+