diff --git a/wallet/wallet.go b/wallet/wallet.go index f6be5cb..6e838bf 100644 --- a/wallet/wallet.go +++ b/wallet/wallet.go @@ -122,17 +122,23 @@ func (w *Wallet) scanTx(tx *types.Transaction, blockHeight uint64) error { // Check key images for spend detection. for _, vin := range tx.Vin { - toKey, ok := vin.(types.TxInputToKey) - if !ok { + var keyImage types.KeyImage + switch v := vin.(type) { + case types.TxInputToKey: + keyImage = v.KeyImage + case types.TxInputHTLC: + keyImage = v.KeyImage + default: continue } + // Try to mark any matching transfer as spent. - tr, err := getTransfer(w.store, toKey.KeyImage) + tr, err := getTransfer(w.store, keyImage) if err != nil { continue // not our transfer } if !tr.Spent { - markTransferSpent(w.store, toKey.KeyImage, blockHeight) + markTransferSpent(w.store, keyImage, blockHeight) } } diff --git a/wallet/wallet_test.go b/wallet/wallet_test.go index adcb338..c113d1a 100644 --- a/wallet/wallet_test.go +++ b/wallet/wallet_test.go @@ -12,11 +12,11 @@ package wallet import ( "testing" - store "dappco.re/go/core/store" "dappco.re/go/core/blockchain/chain" "dappco.re/go/core/blockchain/crypto" "dappco.re/go/core/blockchain/types" "dappco.re/go/core/blockchain/wire" + store "dappco.re/go/core/store" ) func makeTestBlock(t *testing.T, height uint64, prevHash types.Hash, @@ -133,3 +133,55 @@ func TestWalletTransfers(t *testing.T) { t.Fatalf("got %d transfers, want 1", len(transfers)) } } + +func TestWalletScanTxMarksHTLCSpend(t *testing.T) { + s, err := store.New(":memory:") + if err != nil { + t.Fatal(err) + } + defer s.Close() + + acc, err := GenerateAccount() + if err != nil { + t.Fatal(err) + } + + ki := types.KeyImage{0x42} + if err := putTransfer(s, &Transfer{ + KeyImage: ki, + Amount: 100, + BlockHeight: 1, + }); err != nil { + t.Fatal(err) + } + + w := &Wallet{ + store: s, + scanner: NewV1Scanner(acc), + } + + tx := &types.Transaction{ + Version: types.VersionPreHF4, + Vin: []types.TxInput{ + types.TxInputHTLC{ + Amount: 100, + KeyImage: ki, + }, + }, + } + + if err := w.scanTx(tx, 10); err != nil { + t.Fatal(err) + } + + got, err := getTransfer(s, ki) + if err != nil { + t.Fatal(err) + } + if !got.Spent { + t.Fatal("expected HTLC spend to be marked spent") + } + if got.SpentHeight != 10 { + t.Fatalf("spent height = %d, want 10", got.SpentHeight) + } +}