diff --git a/consensus/v2sig_test.go b/consensus/v2sig_test.go index 463cf98..7fdc230 100644 --- a/consensus/v2sig_test.go +++ b/consensus/v2sig_test.go @@ -18,6 +18,18 @@ import ( "github.com/stretchr/testify/require" ) +func buildSingleZCSigRaw() []byte { + var buf bytes.Buffer + enc := wire.NewEncoder(&buf) + enc.WriteVarint(1) + enc.WriteUint8(types.SigTypeZC) + enc.WriteBytes(make([]byte, 64)) + enc.WriteVarint(0) + enc.WriteVarint(0) + enc.WriteBytes(make([]byte, 64)) + return buf.Bytes() +} + // loadTestTx loads and decodes a hex-encoded transaction from testdata. func loadTestTx(t *testing.T, filename string) *types.Transaction { t.Helper() @@ -147,6 +159,20 @@ func TestVerifyV2Signatures_BadSigCount(t *testing.T) { assert.Error(t, err, "should fail with mismatched sig count") } +func TestVerifyV2Signatures_HTLCWrongSigTag_Bad(t *testing.T) { + tx := &types.Transaction{ + Version: types.VersionPostHF5, + Vin: []types.TxInput{ + types.TxInputHTLC{Amount: 100, KeyImage: types.KeyImage{1}}, + }, + SignaturesRaw: buildSingleZCSigRaw(), + } + + err := VerifyTransactionSignatures(tx, config.TestnetForks, 250, nil, nil) + require.Error(t, err) + assert.Contains(t, err.Error(), "HTLC") +} + func TestVerifyV2Signatures_TxHash(t *testing.T) { // Verify the known tx hash matches. tx := loadTestTx(t, "../testdata/v2_spending_tx_mixin0.hex") diff --git a/consensus/verify.go b/consensus/verify.go index 418f521..c55fd8d 100644 --- a/consensus/verify.go +++ b/consensus/verify.go @@ -152,7 +152,8 @@ func verifyV2Signatures(tx *types.Transaction, getZCRingOutputs ZCRingOutputsFn) return coreerr.E("verifyV2Signatures", fmt.Sprintf("consensus: V2 signature count %d != input count %d", len(sigEntries), len(tx.Vin)), nil) } - // Validate that ZC inputs have ZC_sig and vice versa. + // Validate that ZC inputs have ZC_sig and that ring-spending inputs use + // the ring-signature tags that match their spending model. for i, vin := range tx.Vin { switch vin.(type) { case types.TxInputZC: @@ -163,6 +164,10 @@ func verifyV2Signatures(tx *types.Transaction, getZCRingOutputs ZCRingOutputsFn) if sigEntries[i].tag != types.SigTypeNLSAG && sigEntries[i].tag != types.SigTypeVoid { return coreerr.E("verifyV2Signatures", fmt.Sprintf("consensus: input %d is to_key but signature tag is 0x%02x", i, sigEntries[i].tag), nil) } + case types.TxInputHTLC: + if sigEntries[i].tag != types.SigTypeNLSAG && sigEntries[i].tag != types.SigTypeVoid { + return coreerr.E("verifyV2Signatures", fmt.Sprintf("consensus: input %d is HTLC but signature tag is 0x%02x", i, sigEntries[i].tag), nil) + } } }