diff --git a/txscript/engine.go b/txscript/engine.go index 3c2e4142..0fa5cc48 100644 --- a/txscript/engine.go +++ b/txscript/engine.go @@ -5,6 +5,8 @@ package txscript import ( + "bytes" + "crypto/sha256" "fmt" "math/big" @@ -46,7 +48,8 @@ const ( // ScriptVerifyCleanStack defines that the stack must contain only // one stack element after evaluation and that the element must be // true if interpreted as a boolean. This is rule 6 of BIP0062. - // This flag should never be used without the ScriptBip16 flag. + // This flag should never be used without the ScriptBip16 flag nor the + // ScriptVerifyWitness flag. ScriptVerifyCleanStack // ScriptVerifyDERSignatures defines that signatures are required @@ -73,6 +76,14 @@ const ( // ScriptVerifyStrictEncoding defines that signature scripts and // public keys must follow the strict encoding requirements. ScriptVerifyStrictEncoding + + // ScriptVerifyWitness defines whether or not to verify a transaction + // output using a witness program template. + ScriptVerifyWitness + + // ScriptVerifyDiscourageUpgradeableWitnessProgram makes witness + // program with versions 2-16 non-standard. + ScriptVerifyDiscourageUpgradeableWitnessProgram ) const ( @@ -82,6 +93,14 @@ const ( // MaxScriptSize is the maximum allowed length of a raw script. MaxScriptSize = 10000 + + // payToWitnessPubKeyHashDataSize is the size of the witness program's + // data push for a pay-to-witness-pub-key-hash output. + payToWitnessPubKeyHashDataSize = 20 + + // payToWitnessScriptHashDataSize is the size of the witness program's + // data push for a pay-to-witness-script-hash output. + payToWitnessScriptHashDataSize = 32 ) // halforder is used to tame ECDSA malleability (see BIP0062). @@ -101,8 +120,13 @@ type Engine struct { numOps int flags ScriptFlags sigCache *SigCache + hashCache *TxSigHashes bip16 bool // treat execution as pay-to-script-hash savedFirstStack [][]byte // stack from first script for bip16 scripts + witness bool // treat execution as a witness program + witnessVersion int + witnessProgram []byte + inputAmount int64 } // hasFlag returns whether the script engine instance has the passed flag set. @@ -209,6 +233,115 @@ func (vm *Engine) curPC() (script int, off int, err error) { return vm.scriptIdx, vm.scriptOff, nil } +// verifyWitnessProgram validates the stored witness program using the passed +// witness as input. +func (vm *Engine) verifyWitnessProgram(witness [][]byte) error { + if vm.witnessVersion == 0 { + switch len(vm.witnessProgram) { + case payToWitnessPubKeyHashDataSize: // P2WKH + // The witness stack should consist of exactly two + // items: the signature, and the pubkey. + if len(witness) != 2 { + err := fmt.Sprintf("should have exactly two "+ + "items in witness, instead have %v", len(witness)) + return scriptError(ErrWitnessScriptMismatch, err) + } + + // Now we'll resume execution as if it were a regular + // p2pkh transaction. + pkScript, err := payToPubKeyHashScript(vm.witnessProgram) + if err != nil { + return err + } + pops, err := parseScript(pkScript) + if err != nil { + return err + } + + // Set the stack to the provided witness stack, then + // append the pkScript generated above as the next + // script to execute. + vm.scripts = append(vm.scripts, pops) + vm.SetStack(witness) + + case payToWitnessScriptHashDataSize: // P2WSH + + // Additionally, The witness stack MUST NOT be empty at + // this point. + if len(witness) == 0 { + return scriptError(ErrWitnessProgramEmpty, "witness "+ + "program empty passed empty witness") + } + + // Obtain the witness script which should be the last + // element in the passed stack. The size of the script + // MUST NOT exceed the max script size. + witnessScript := witness[len(witness)-1] + if len(witnessScript) > MaxScriptSize { + str := fmt.Sprintf("witnessScript size %d "+ + "is larger than max allowed size %d", + len(witnessScript), MaxScriptSize) + return scriptError(ErrScriptTooBig, str) + } + + // Ensure that the serialized pkScript at the end of + // the witness stack matches the witness program. + witnessHash := sha256.Sum256(witnessScript) + if !bytes.Equal(witnessHash[:], vm.witnessProgram) { + return scriptError(ErrWitnessScriptMismatch, + "witness program hash mismatch") + } + + // With all the validity checks passed, parse the + // script into individual op-codes so w can execute it + // as the next script. + pops, err := parseScript(witnessScript) + if err != nil { + return err + } + + // The hash matched successfully, so use the witness as + // the stack, and set the witnessScript to be the next + // script executed. + vm.scripts = append(vm.scripts, pops) + vm.SetStack(witness[:len(witness)-1]) + + default: + errStr := fmt.Sprintf("length of witness program "+ + "must either be %v or %v bytes, instead is %v bytes", + payToWitnessPubKeyHashDataSize, + payToWitnessScriptHashDataSize, + len(vm.witnessProgram)) + return scriptError(ErrWitnessProgramWrongLength, errStr) + } + } else if vm.hasFlag(ScriptVerifyDiscourageUpgradeableWitnessProgram) { + return fmt.Errorf("new witness program versions invalid: %v", + vm.witnessVersion) + } else { + // If we encounter an unknown witness program version and we + // aren't discouraging future unknown witness based soft-forks, + // then we de-activate the segwit behavior within the VM for + // the remainder of execution. + vm.witness = false + } + + if vm.witness { + // All elements within the witness stack must not be greater + // than the maximum bytes which are allowed to be pushed onto + // the stack. + for _, witElement := range vm.GetStack() { + if len(witElement) > MaxScriptElementSize { + str := fmt.Sprintf("element size %d exceeds "+ + "max allowed size %d", len(witElement), + MaxScriptElementSize) + return scriptError(ErrElementTooBig, str) + } + } + } + + return nil +} + // DisasmPC returns the string for the disassembly of the opcode that will be // next to execute when Step() is called. func (vm *Engine) DisasmPC() (string, error) { @@ -246,6 +379,14 @@ func (vm *Engine) CheckErrorCondition(finalScript bool) error { return scriptError(ErrScriptUnfinished, "error check when script unfinished") } + + // If we're in witness execution mode, and this was the final script, + // then the stack MUST be clean in order to maintain compatibility with + // BIP16. + if finalScript && vm.witness && vm.dstack.Depth() != 1 { + return ErrStackCleanStack + } + if finalScript && vm.hasFlag(ScriptVerifyCleanStack) && vm.dstack.Depth() != 1 { @@ -343,6 +484,14 @@ func (vm *Engine) Step() (done bool, err error) { // Set stack to be the stack from first script minus the // script itself vm.SetStack(vm.savedFirstStack[:len(vm.savedFirstStack)-1]) + } else if (vm.scriptIdx == 1 && vm.witness) || + (vm.scriptIdx == 2 && vm.witness && vm.bip16) { // Nested P2SH. + vm.scriptIdx++ + + witness := vm.tx.TxIn[vm.txIdx].Witness + if err := vm.verifyWitnessProgram(witness); err != nil { + return false, err + } } else { vm.scriptIdx++ } @@ -626,7 +775,9 @@ func (vm *Engine) SetAltStack(data [][]byte) { // NewEngine returns a new script engine for the provided public key script, // transaction, and input index. The flags modify the behavior of the script // engine according to the description provided by each flag. -func NewEngine(scriptPubKey []byte, tx *wire.MsgTx, txIdx int, flags ScriptFlags, sigCache *SigCache) (*Engine, error) { +func NewEngine(scriptPubKey []byte, tx *wire.MsgTx, txIdx int, flags ScriptFlags, + sigCache *SigCache, hashCache *TxSigHashes, inputAmount int64) (*Engine, error) { + // The provided transaction input index must refer to a valid input. if txIdx < 0 || txIdx >= len(tx.TxIn) { str := fmt.Sprintf("transaction input index %d is negative or "+ @@ -735,17 +886,14 @@ func NewEngine(scriptPubKey []byte, tx *wire.MsgTx, txIdx int, flags ScriptFlags witProgram = scriptPubKey case len(tx.TxIn[txIdx].Witness) != 0 && vm.bip16: - pops, err := parseScript(scriptSig) - if err != nil { - return nil, err - } - // The sigScript MUST be *exactly* a single canonical // data push of the witness program, otherwise we // reintroduce malleability. - if len(pops) == 1 && canonicalPush(pops[0]) && - IsWitnessProgram(pops[0].data) { - witProgram = pops[0].data + sigPops := vm.scripts[0] + if len(sigPops) == 1 && canonicalPush(sigPops[0]) && + IsWitnessProgram(sigPops[0].data) { + + witProgram = sigPops[0].data } else { errStr := "signature script for witness " + "nested p2sh is not canonical" diff --git a/txscript/engine_test.go b/txscript/engine_test.go index c7266b68..2e8c522c 100644 --- a/txscript/engine_test.go +++ b/txscript/engine_test.go @@ -54,7 +54,7 @@ func TestBadPC(t *testing.T) { pkScript := mustParseShortForm("NOP") for _, test := range tests { - vm, err := NewEngine(pkScript, tx, 0, 0, nil) + vm, err := NewEngine(pkScript, tx, 0, 0, nil, nil, -1) if err != nil { t.Errorf("Failed to create script: %v", err) } @@ -111,7 +111,7 @@ func TestCheckErrorCondition(t *testing.T) { pkScript := mustParseShortForm("NOP NOP NOP NOP NOP NOP NOP NOP NOP" + " NOP TRUE") - vm, err := NewEngine(pkScript, tx, 0, 0, nil) + vm, err := NewEngine(pkScript, tx, 0, 0, nil, nil, 0) if err != nil { t.Errorf("failed to create script: %v", err) } @@ -187,7 +187,7 @@ func TestInvalidFlagCombinations(t *testing.T) { pkScript := []byte{OP_NOP} for i, test := range tests { - _, err := NewEngine(pkScript, tx, 0, test, nil) + _, err := NewEngine(pkScript, tx, 0, test, nil, nil, -1) if !IsErrorCode(err, ErrInvalidFlags) { t.Fatalf("TestInvalidFlagCombinations #%d unexpected "+ "error: %v", i, err) diff --git a/txscript/error.go b/txscript/error.go index 107a09b8..a47b9436 100644 --- a/txscript/error.go +++ b/txscript/error.go @@ -227,6 +227,34 @@ const ( // reached. ErrUnsatisfiedLockTime + // ErrWitnessProgramEmpty is returned if ScriptVerifyWitness is set and + // the witness stack itself is empty. + ErrWitnessProgramEmpty + + // ErrWitnessScriptMismatch is returned if ScriptVerifyWitness is set + // and the witness itself for a p2wkh witness program isn't *exactly* 2 + // items. + ErrWitnessScriptMismatch + + // ErrWitnessProgramWrongLength is returned if ScriptVerifyWitness is + // set and the length of the witness program violates the length as + // dictated by the current witness version. + ErrWitnessProgramWrongLength + + // ErrWitnessMalleated is returned if ScriptVerifyWitness is set and a + // native p2wsh program is encountered which has a non-empty sigScript. + ErrWitnessMalleated + + // ErrWitnessMalleatedP2SH is returned if ScriptVerifyWitness if set + // and the validation logic for nested p2sh encounters a sigScript + // which isn't *exactyl* a datapush of the witness program. + ErrWitnessMalleatedP2SH + + // ErrWitnessUnexpected is returned if ScriptVerifyWitness is set and a + // transaction includes witness data but doesn't spend an which is a + // witness program (nested or native). + ErrWitnessUnexpected + // numErrorCodes is the maximum error code number used in tests. This // entry MUST be the last entry in the enum. numErrorCodes diff --git a/txscript/opcode.go b/txscript/opcode.go index 2a289d20..327f7639 100644 --- a/txscript/opcode.go +++ b/txscript/opcode.go @@ -2045,10 +2045,6 @@ func opcodeCheckSig(op *parsedOpcode, vm *Engine) error { // Get script starting from the most recent OP_CODESEPARATOR. subScript := vm.subScript() - // Remove the signature since there is no way for a signature to sign - // itself. - subScript = removeOpcodeByData(subScript, fullSigBytes) - // Generate the signature hash based on the signature hash type. var hash []byte if vm.witness { @@ -2065,6 +2061,10 @@ func opcodeCheckSig(op *parsedOpcode, vm *Engine) error { return err } } else { + // Remove the signature since there is no way for a signature + // to sign itself. + subScript = removeOpcodeByData(subScript, fullSigBytes) + hash = calcSignatureHash(subScript, hashType, &vm.tx, vm.txIdx) } @@ -2232,10 +2232,12 @@ func opcodeCheckMultiSig(op *parsedOpcode, vm *Engine) error { // Get script starting from the most recent OP_CODESEPARATOR. script := vm.subScript() - // Remove any of the signatures since there is no way for a signature to - // sign itself. - for _, sigInfo := range signatures { - script = removeOpcodeByData(script, sigInfo.signature) + // Remove the signature in pre-segwit scripts since there is no way for + // a signature to sign itself. + if !vm.witness { + for _, sigInfo := range signatures { + script = removeOpcodeByData(script, sigInfo.signature) + } } success := true diff --git a/txscript/sign_test.go b/txscript/sign_test.go index aed6010b..b97a8a64 100644 --- a/txscript/sign_test.go +++ b/txscript/sign_test.go @@ -53,10 +53,10 @@ func mkGetScript(scripts map[string][]byte) ScriptDB { }) } -func checkScripts(msg string, tx *wire.MsgTx, idx int, sigScript, pkScript []byte) error { +func checkScripts(msg string, tx *wire.MsgTx, idx int, inputAmt int64, sigScript, pkScript []byte) error { tx.TxIn[idx].SignatureScript = sigScript vm, err := NewEngine(pkScript, tx, idx, - ScriptBip16|ScriptVerifyDERSignatures, nil) + ScriptBip16|ScriptVerifyDERSignatures, nil, nil, inputAmt) if err != nil { return fmt.Errorf("failed to make script engine for %s: %v", msg, err) @@ -71,7 +71,7 @@ func checkScripts(msg string, tx *wire.MsgTx, idx int, sigScript, pkScript []byt return nil } -func signAndCheck(msg string, tx *wire.MsgTx, idx int, pkScript []byte, +func signAndCheck(msg string, tx *wire.MsgTx, idx int, inputAmt int64, pkScript []byte, hashType SigHashType, kdb KeyDB, sdb ScriptDB, previousScript []byte) error { @@ -81,7 +81,7 @@ func signAndCheck(msg string, tx *wire.MsgTx, idx int, pkScript []byte, return fmt.Errorf("failed to sign output %s: %v", msg, err) } - return checkScripts(msg, tx, idx, sigScript, pkScript) + return checkScripts(msg, tx, idx, inputAmt, sigScript, pkScript) } func TestSignTxOutput(t *testing.T) { @@ -99,6 +99,7 @@ func TestSignTxOutput(t *testing.T) { SigHashNone | SigHashAnyOneCanPay, SigHashSingle | SigHashAnyOneCanPay, } + inputAmounts := []int64{5, 10, 15} tx := &wire.MsgTx{ Version: 1, TxIn: []*wire.TxIn{ @@ -165,7 +166,7 @@ func TestSignTxOutput(t *testing.T) { "for %s: %v", msg, err) } - if err := signAndCheck(msg, tx, i, pkScript, hashType, + if err := signAndCheck(msg, tx, i, inputAmounts[i], pkScript, hashType, mkGetKey(map[string]addressToKey{ address.EncodeAddress(): {key, false}, }), mkGetScript(nil), nil); err != nil { @@ -226,7 +227,7 @@ func TestSignTxOutput(t *testing.T) { break } - err = checkScripts(msg, tx, i, sigScript, pkScript) + err = checkScripts(msg, tx, i, inputAmounts[i], sigScript, pkScript) if err != nil { t.Errorf("twice signed script invalid for "+ "%s: %v", msg, err) @@ -263,7 +264,8 @@ func TestSignTxOutput(t *testing.T) { "for %s: %v", msg, err) } - if err := signAndCheck(msg, tx, i, pkScript, hashType, + if err := signAndCheck(msg, tx, i, inputAmounts[i], + pkScript, hashType, mkGetKey(map[string]addressToKey{ address.EncodeAddress(): {key, true}, }), mkGetScript(nil), nil); err != nil { @@ -325,7 +327,8 @@ func TestSignTxOutput(t *testing.T) { break } - err = checkScripts(msg, tx, i, sigScript, pkScript) + err = checkScripts(msg, tx, i, inputAmounts[i], + sigScript, pkScript) if err != nil { t.Errorf("twice signed script invalid for "+ "%s: %v", msg, err) @@ -362,7 +365,8 @@ func TestSignTxOutput(t *testing.T) { "for %s: %v", msg, err) } - if err := signAndCheck(msg, tx, i, pkScript, hashType, + if err := signAndCheck(msg, tx, i, inputAmounts[i], + pkScript, hashType, mkGetKey(map[string]addressToKey{ address.EncodeAddress(): {key, false}, }), mkGetScript(nil), nil); err != nil { @@ -424,7 +428,7 @@ func TestSignTxOutput(t *testing.T) { break } - err = checkScripts(msg, tx, i, sigScript, pkScript) + err = checkScripts(msg, tx, i, inputAmounts[i], sigScript, pkScript) if err != nil { t.Errorf("twice signed script invalid for "+ "%s: %v", msg, err) @@ -461,7 +465,8 @@ func TestSignTxOutput(t *testing.T) { "for %s: %v", msg, err) } - if err := signAndCheck(msg, tx, i, pkScript, hashType, + if err := signAndCheck(msg, tx, i, inputAmounts[i], + pkScript, hashType, mkGetKey(map[string]addressToKey{ address.EncodeAddress(): {key, true}, }), mkGetScript(nil), nil); err != nil { @@ -523,7 +528,8 @@ func TestSignTxOutput(t *testing.T) { break } - err = checkScripts(msg, tx, i, sigScript, pkScript) + err = checkScripts(msg, tx, i, inputAmounts[i], + sigScript, pkScript) if err != nil { t.Errorf("twice signed script invalid for "+ "%s: %v", msg, err) @@ -577,8 +583,8 @@ func TestSignTxOutput(t *testing.T) { break } - if err := signAndCheck(msg, tx, i, scriptPkScript, - hashType, + if err := signAndCheck(msg, tx, i, inputAmounts[i], + scriptPkScript, hashType, mkGetKey(map[string]addressToKey{ address.EncodeAddress(): {key, false}, }), mkGetScript(map[string][]byte{ @@ -662,7 +668,8 @@ func TestSignTxOutput(t *testing.T) { break } - err = checkScripts(msg, tx, i, sigScript, scriptPkScript) + err = checkScripts(msg, tx, i, inputAmounts[i], + sigScript, scriptPkScript) if err != nil { t.Errorf("twice signed script invalid for "+ "%s: %v", msg, err) @@ -715,8 +722,8 @@ func TestSignTxOutput(t *testing.T) { break } - if err := signAndCheck(msg, tx, i, scriptPkScript, - hashType, + if err := signAndCheck(msg, tx, i, inputAmounts[i], + scriptPkScript, hashType, mkGetKey(map[string]addressToKey{ address.EncodeAddress(): {key, true}, }), mkGetScript(map[string][]byte{ @@ -800,7 +807,8 @@ func TestSignTxOutput(t *testing.T) { break } - err = checkScripts(msg, tx, i, sigScript, scriptPkScript) + err = checkScripts(msg, tx, i, inputAmounts[i], + sigScript, scriptPkScript) if err != nil { t.Errorf("twice signed script invalid for "+ "%s: %v", msg, err) @@ -853,8 +861,8 @@ func TestSignTxOutput(t *testing.T) { break } - if err := signAndCheck(msg, tx, i, scriptPkScript, - hashType, + if err := signAndCheck(msg, tx, i, inputAmounts[i], + scriptPkScript, hashType, mkGetKey(map[string]addressToKey{ address.EncodeAddress(): {key, false}, }), mkGetScript(map[string][]byte{ @@ -937,7 +945,8 @@ func TestSignTxOutput(t *testing.T) { break } - err = checkScripts(msg, tx, i, sigScript, scriptPkScript) + err = checkScripts(msg, tx, i, inputAmounts[i], + sigScript, scriptPkScript) if err != nil { t.Errorf("twice signed script invalid for "+ "%s: %v", msg, err) @@ -989,8 +998,8 @@ func TestSignTxOutput(t *testing.T) { break } - if err := signAndCheck(msg, tx, i, scriptPkScript, - hashType, + if err := signAndCheck(msg, tx, i, inputAmounts[i], + scriptPkScript, hashType, mkGetKey(map[string]addressToKey{ address.EncodeAddress(): {key, true}, }), mkGetScript(map[string][]byte{ @@ -1073,7 +1082,8 @@ func TestSignTxOutput(t *testing.T) { break } - err = checkScripts(msg, tx, i, sigScript, scriptPkScript) + err = checkScripts(msg, tx, i, inputAmounts[i], + sigScript, scriptPkScript) if err != nil { t.Errorf("twice signed script invalid for "+ "%s: %v", msg, err) @@ -1144,8 +1154,8 @@ func TestSignTxOutput(t *testing.T) { break } - if err := signAndCheck(msg, tx, i, scriptPkScript, - hashType, + if err := signAndCheck(msg, tx, i, inputAmounts[i], + scriptPkScript, hashType, mkGetKey(map[string]addressToKey{ address1.EncodeAddress(): {key1, true}, address2.EncodeAddress(): {key2, true}, @@ -1234,7 +1244,7 @@ func TestSignTxOutput(t *testing.T) { } // Only 1 out of 2 signed, this *should* fail. - if checkScripts(msg, tx, i, sigScript, + if checkScripts(msg, tx, i, inputAmounts[i], sigScript, scriptPkScript) == nil { t.Errorf("part signed script valid for %s", msg) break @@ -1253,7 +1263,7 @@ func TestSignTxOutput(t *testing.T) { break } - err = checkScripts(msg, tx, i, sigScript, + err = checkScripts(msg, tx, i, inputAmounts[i], sigScript, scriptPkScript) if err != nil { t.Errorf("fully signed script invalid for "+ @@ -1340,7 +1350,7 @@ func TestSignTxOutput(t *testing.T) { } // Only 1 out of 2 signed, this *should* fail. - if checkScripts(msg, tx, i, sigScript, + if checkScripts(msg, tx, i, inputAmounts[i], sigScript, scriptPkScript) == nil { t.Errorf("part signed script valid for %s", msg) break @@ -1361,8 +1371,8 @@ func TestSignTxOutput(t *testing.T) { } // Now we should pass. - err = checkScripts(msg, tx, i, sigScript, - scriptPkScript) + err = checkScripts(msg, tx, i, inputAmounts[i], + sigScript, scriptPkScript) if err != nil { t.Errorf("fully signed script invalid for "+ "%s: %v", msg, err) @@ -1635,7 +1645,7 @@ nexttest: tx.AddTxOut(output) for range sigScriptTests[i].inputs { - txin := wire.NewTxIn(coinbaseOutPoint, nil) + txin := wire.NewTxIn(coinbaseOutPoint, nil, nil) tx.AddTxIn(txin) } @@ -1683,8 +1693,8 @@ nexttest: // Validate tx input scripts scriptFlags := ScriptBip16 | ScriptVerifyDERSignatures for j := range tx.TxIn { - vm, err := NewEngine(sigScriptTests[i].inputs[j].txout. - PkScript, tx, j, scriptFlags, nil) + vm, err := NewEngine(sigScriptTests[i]. + inputs[j].txout.PkScript, tx, j, scriptFlags, nil, nil, 0) if err != nil { t.Errorf("cannot create script vm for test %v: %v", sigScriptTests[i].name, err) diff --git a/txscript/standard.go b/txscript/standard.go index 14ea54f4..b8ffc42a 100644 --- a/txscript/standard.go +++ b/txscript/standard.go @@ -37,7 +37,9 @@ const ( ScriptVerifyNullFail | ScriptVerifyCheckLockTimeVerify | ScriptVerifyCheckSequenceVerify | - ScriptVerifyLowS + ScriptVerifyLowS | + ScriptVerifyWitness | + ScriptVerifyDiscourageUpgradeableWitnessProgram ) // ScriptClass is an enumeration for the list of standard types of script.