Fix tx file serialization and tests.

This commit is contained in:
Josh Rickmar 2014-01-06 13:35:07 -05:00
parent e8265eca41
commit 40675d1bef
2 changed files with 55 additions and 29 deletions

View file

@ -51,24 +51,29 @@ type ReaderFromVersion interface {
io.WriterTo io.WriterTo
} }
// Various versions. // Various UTXO file versions.
var ( const (
// First file version used. utxoVersFirst uint32 = iota
utxoVersFirst uint32 = 0 )
txVersFirst uint32 = 0
// Various Tx file versions.
const (
txVersFirst uint32 = iota
// txVersRecvTxIndex is the version where the txout index // txVersRecvTxIndex is the version where the txout index
// was added to the RecvTx struct. // was added to the RecvTx struct.
txVersRecvTxIndex uint32 = 1 txVersRecvTxIndex
// txVersMarkSentChange is the version where serialized SentTx // txVersMarkSentChange is the version where serialized SentTx
// added a flags field, used for marking a sent transaction // added a flags field, used for marking a sent transaction
// as change. // as change.
txVersMarkSentChange uint32 = 2 txVersMarkSentChange
)
// Current versions // Current versions.
const (
utxoVersCurrent = utxoVersFirst utxoVersCurrent = utxoVersFirst
txVersCurrent = txVersRecvTxIndex txVersCurrent = txVersMarkSentChange
) )
// UtxoStore is a type used for holding all Utxo structures for all // UtxoStore is a type used for holding all Utxo structures for all
@ -344,6 +349,18 @@ func (p *Pair) WriteTo(w io.Writer) (int64, error) {
} }
written += int64(nw) written += int64(nw)
// Set and write flags.
flags := byte(0)
if p.Change {
flags |= 1 << 0
}
flagBytes := []byte{flags}
nw, err = w.Write(flagBytes)
if err != nil {
return written + int64(nw), err
}
written += int64(nw)
return written, nil return written, nil
} }
@ -699,7 +716,7 @@ func (s *PkScript) WriteTo(w io.Writer) (n int64, err error) {
func (txs *TxStore) ReadFrom(r io.Reader) (int64, error) { func (txs *TxStore) ReadFrom(r io.Reader) (int64, error) {
var read int64 var read int64
// Read the file version. This is currently not used. // Read the file version.
versionBytes := make([]byte, 4) // bytes for a uint32 versionBytes := make([]byte, 4) // bytes for a uint32
n, err := r.Read(versionBytes) n, err := r.Read(versionBytes)
if err != nil { if err != nil {
@ -762,9 +779,9 @@ func (txs *TxStore) ReadFrom(r io.Reader) (int64, error) {
func (txs *TxStore) WriteTo(w io.Writer) (int64, error) { func (txs *TxStore) WriteTo(w io.Writer) (int64, error) {
var written int64 var written int64
// Write file version. This is currently not used. // Write file version.
versionBytes := make([]byte, 4) // bytes for a uint32 versionBytes := make([]byte, 4) // bytes for a uint32
binary.LittleEndian.PutUint32(versionBytes, utxoVersCurrent) binary.LittleEndian.PutUint32(versionBytes, txVersCurrent)
n, err := w.Write(versionBytes) n, err := w.Write(versionBytes)
if err != nil { if err != nil {
return int64(n), err return int64(n), err
@ -773,30 +790,32 @@ func (txs *TxStore) WriteTo(w io.Writer) (int64, error) {
store := ([]interface{})(*txs) store := ([]interface{})(*txs)
for _, tx := range store { for _, tx := range store {
// Write header for tx.
var header byte
switch tx.(type) { switch tx.(type) {
case *RecvTx: case *RecvTx:
n, err := binaryWrite(w, binary.LittleEndian, recvTxHeader) header = recvTxHeader
if err != nil {
return written + n, err
}
written += n
case *SendTx: case *SendTx:
n, err := binaryWrite(w, binary.LittleEndian, sendTxHeader) header = sendTxHeader
if err != nil {
return written + n, err
}
written += n
default: default:
return written, fmt.Errorf("unknown type in TxStore") return written, fmt.Errorf("unknown type in TxStore")
} }
wt := tx.(io.WriterTo) headerBytes := []byte{header}
n, err := wt.WriteTo(w) n, err := w.Write(headerBytes)
if err != nil { if err != nil {
return written + n, err return written + int64(n), err
} }
written += n written += int64(n)
// Write tx.
wt := tx.(io.WriterTo)
n64, err := wt.WriteTo(w)
if err != nil {
return written + n64, err
}
written += n64
} }
return written, nil return written, nil
} }
@ -964,6 +983,7 @@ func (tx *RecvTx) ReadFrom(r io.Reader) (n int64, err error) {
// w in the format: // w in the format:
// //
// TxID (32 bytes) // TxID (32 bytes)
// TxOutIdx (4 bytes, little endian)
// TimeReceived (8 bytes, little endian) // TimeReceived (8 bytes, little endian)
// BlockHeight (4 bytes, little endian) // BlockHeight (4 bytes, little endian)
// BlockHash (32 bytes) // BlockHash (32 bytes)
@ -974,6 +994,7 @@ func (tx *RecvTx) ReadFrom(r io.Reader) (n int64, err error) {
func (tx *RecvTx) WriteTo(w io.Writer) (n int64, err error) { func (tx *RecvTx) WriteTo(w io.Writer) (n int64, err error) {
datas := []interface{}{ datas := []interface{}{
&tx.TxID, &tx.TxID,
&tx.TxOutIdx,
&tx.TimeReceived, &tx.TimeReceived,
&tx.BlockHeight, &tx.BlockHeight,
&tx.BlockHash, &tx.BlockHash,
@ -1046,8 +1067,12 @@ func (tx *SendTx) ReadFromVersion(vers uint32, r io.Reader) (n int64, err error)
} }
for _, data := range datas { for _, data := range datas {
switch e := data.(type) { switch e := data.(type) {
case ReaderFromVersion:
read, err = e.ReadFromVersion(vers, r)
case io.ReaderFrom: case io.ReaderFrom:
read, err = e.ReadFrom(r) read, err = e.ReadFrom(r)
default: default:
read, err = binaryRead(r, binary.LittleEndian, data) read, err = binaryRead(r, binary.LittleEndian, data)
} }

View file

@ -33,6 +33,7 @@ var (
16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29,
30, 31, 30, 31,
}, },
TxOutIdx: 0,
BlockHash: [btcwire.HashSize]byte{ BlockHash: [btcwire.HashSize]byte{
0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,
16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29,
@ -203,7 +204,7 @@ func TestRecvTxWriteRead(t *testing.T) {
tx := new(RecvTx) tx := new(RecvTx)
n, err = tx.ReadFrom(bytes.NewBuffer(txBytes)) n, err = tx.ReadFrom(bytes.NewBuffer(txBytes))
if err != nil { if err != nil {
t.Error(err) t.Errorf("Read %v bytes before erroring with: %v", n, err)
return return
} }
@ -237,7 +238,7 @@ func TestSendTxWriteRead(t *testing.T) {
tx := new(SendTx) tx := new(SendTx)
n2, err := tx.ReadFrom(bytes.NewBuffer(txBytes)) n2, err := tx.ReadFrom(bytes.NewBuffer(txBytes))
if err != nil { if err != nil {
t.Error(err) t.Errorf("Read %v bytes before erroring with: %v", n2, err)
return return
} }
if n1 != n2 { if n1 != n2 {
@ -279,7 +280,7 @@ func TestTxStoreWriteRead(t *testing.T) {
txs := TxStore{} txs := TxStore{}
n2, err := txs.ReadFrom(bytes.NewBuffer(txsBytes)) n2, err := txs.ReadFrom(bytes.NewBuffer(txsBytes))
if err != nil { if err != nil {
t.Error(err) t.Errorf("Read %v bytes before erroring with: %v", n2, err)
return return
} }
if n1 != n2 { if n1 != n2 {