diff --git a/chain/ring.go b/chain/ring.go index ba52d31..3270fc5 100644 --- a/chain/ring.go +++ b/chain/ring.go @@ -36,11 +36,11 @@ func (c *Chain) GetRingOutputs(amount uint64, offsets []uint64) ([]types.PublicK switch out := tx.Vout[outNo].(type) { case types.TxOutputBare: - toKey, ok := out.Target.(types.TxOutToKey) - if !ok { - return nil, coreerr.E("Chain.GetRingOutputs", fmt.Sprintf("ring output %d: unsupported target type %T", i, out.Target), nil) + key, err := ringOutputKey(out.Target) + if err != nil { + return nil, coreerr.E("Chain.GetRingOutputs", fmt.Sprintf("ring output %d: %v", i, err), nil) } - pubs[i] = toKey.Key + pubs[i] = key default: return nil, coreerr.E("Chain.GetRingOutputs", fmt.Sprintf("ring output %d: unsupported output type %T", i, out), nil) } @@ -48,6 +48,28 @@ func (c *Chain) GetRingOutputs(amount uint64, offsets []uint64) ([]types.PublicK return pubs, nil } +// ringOutputKey extracts the spend key for a transparent output target. +// +// TxOutMultisig and TxOutHTLC do not carry enough context here to select the +// exact spend path, so we return a deterministic key that keeps ring lookup +// usable for basic sync and verification. More specific spend-context logic +// belongs in the wallet layer. +func ringOutputKey(target types.TxOutTarget) (types.PublicKey, error) { + switch t := target.(type) { + case types.TxOutToKey: + return t.Key, nil + case types.TxOutMultisig: + if len(t.Keys) == 0 { + return types.PublicKey{}, fmt.Errorf("unsupported multisig target with no keys") + } + return t.Keys[0], nil + case types.TxOutHTLC: + return t.PKRedeem, nil + default: + return types.PublicKey{}, fmt.Errorf("unsupported target type %T", target) + } +} + // GetZCRingOutputs fetches ZC ring members (stealth address, amount commitment, // blinded asset ID) for the given global output indices. This implements the // consensus.ZCRingOutputsFn signature for post-HF4 CLSAG GGX verification. diff --git a/chain/ring_test.go b/chain/ring_test.go index 4f7fd23..2328767 100644 --- a/chain/ring_test.go +++ b/chain/ring_test.go @@ -55,6 +55,85 @@ func TestGetRingOutputs_Good(t *testing.T) { } } +func TestGetRingOutputs_Good_Multisig(t *testing.T) { + c := newTestChain(t) + + first := types.PublicKey{0x11, 0x22, 0x33} + second := types.PublicKey{0x44, 0x55, 0x66} + tx := types.Transaction{ + Version: types.VersionPreHF4, + Vin: []types.TxInput{types.TxInputGenesis{Height: 0}}, + Vout: []types.TxOutput{ + types.TxOutputBare{ + Amount: 1000, + Target: types.TxOutMultisig{ + MinimumSigs: 2, + Keys: []types.PublicKey{first, second}, + }, + }, + }, + Extra: wire.EncodeVarint(0), + Attachment: wire.EncodeVarint(0), + } + txHash := wire.TransactionHash(&tx) + + if err := c.PutTransaction(txHash, &tx, &TxMeta{KeeperBlock: 0, GlobalOutputIndexes: []uint64{0}}); err != nil { + t.Fatalf("PutTransaction: %v", err) + } + if _, err := c.PutOutput(1000, txHash, 0); err != nil { + t.Fatalf("PutOutput: %v", err) + } + + pubs, err := c.GetRingOutputs(1000, []uint64{0}) + if err != nil { + t.Fatalf("GetRingOutputs: %v", err) + } + if pubs[0] != first { + t.Errorf("pubs[0]: got %x, want %x", pubs[0], first) + } +} + +func TestGetRingOutputs_Good_HTLC(t *testing.T) { + c := newTestChain(t) + + redeem := types.PublicKey{0xAA, 0xBB, 0xCC} + refund := types.PublicKey{0xDD, 0xEE, 0xFF} + tx := types.Transaction{ + Version: types.VersionPreHF4, + Vin: []types.TxInput{types.TxInputGenesis{Height: 0}}, + Vout: []types.TxOutput{ + types.TxOutputBare{ + Amount: 1000, + Target: types.TxOutHTLC{ + HTLCHash: types.Hash{0x01}, + Flags: 0, + Expiration: 200, + PKRedeem: redeem, + PKRefund: refund, + }, + }, + }, + Extra: wire.EncodeVarint(0), + Attachment: wire.EncodeVarint(0), + } + txHash := wire.TransactionHash(&tx) + + if err := c.PutTransaction(txHash, &tx, &TxMeta{KeeperBlock: 0, GlobalOutputIndexes: []uint64{0}}); err != nil { + t.Fatalf("PutTransaction: %v", err) + } + if _, err := c.PutOutput(1000, txHash, 0); err != nil { + t.Fatalf("PutOutput: %v", err) + } + + pubs, err := c.GetRingOutputs(1000, []uint64{0}) + if err != nil { + t.Fatalf("GetRingOutputs: %v", err) + } + if pubs[0] != redeem { + t.Errorf("pubs[0]: got %x, want %x", pubs[0], redeem) + } +} + func TestGetRingOutputs_Good_MultipleOutputs(t *testing.T) { c := newTestChain(t)