diff --git a/wire/transaction.go b/wire/transaction.go index e53daed..583b75f 100644 --- a/wire/transaction.go +++ b/wire/transaction.go @@ -289,6 +289,20 @@ func encodeOutputsV1(enc *Encoder, vout []types.TxOutput) { enc.WriteVariantTag(types.TargetTypeToKey) enc.WriteBlob32((*[32]byte)(&tgt.Key)) enc.WriteUint8(tgt.MixAttr) + case types.TxOutMultisig: + enc.WriteVariantTag(types.TargetTypeMultisig) + enc.WriteVarint(tgt.MinimumSigs) + enc.WriteVarint(uint64(len(tgt.Keys))) + for i := range tgt.Keys { + enc.WriteBlob32((*[32]byte)(&tgt.Keys[i])) + } + case types.TxOutHTLC: + enc.WriteVariantTag(types.TargetTypeHTLC) + enc.WriteBlob32((*[32]byte)(&tgt.HTLCHash)) + enc.WriteUint8(tgt.Flags) + enc.WriteVarint(tgt.Expiration) + enc.WriteBlob32((*[32]byte)(&tgt.PKRedeem)) + enc.WriteBlob32((*[32]byte)(&tgt.PKRefund)) } } } @@ -313,6 +327,25 @@ func decodeOutputsV1(dec *Decoder) []types.TxOutput { dec.ReadBlob32((*[32]byte)(&tgt.Key)) tgt.MixAttr = dec.ReadUint8() out.Target = tgt + case types.TargetTypeMultisig: + var tgt types.TxOutMultisig + tgt.MinimumSigs = dec.ReadVarint() + keyCount := dec.ReadVarint() + if keyCount > 0 && dec.Err() == nil { + tgt.Keys = make([]types.PublicKey, keyCount) + for j := uint64(0); j < keyCount; j++ { + dec.ReadBlob32((*[32]byte)(&tgt.Keys[j])) + } + } + out.Target = tgt + case types.TargetTypeHTLC: + var tgt types.TxOutHTLC + dec.ReadBlob32((*[32]byte)(&tgt.HTLCHash)) + tgt.Flags = dec.ReadUint8() + tgt.Expiration = dec.ReadVarint() + dec.ReadBlob32((*[32]byte)(&tgt.PKRedeem)) + dec.ReadBlob32((*[32]byte)(&tgt.PKRefund)) + out.Target = tgt default: dec.err = fmt.Errorf("wire: unsupported target tag 0x%02x", tag) return vout @@ -335,6 +368,20 @@ func encodeOutputsV2(enc *Encoder, vout []types.TxOutput) { enc.WriteVariantTag(types.TargetTypeToKey) enc.WriteBlob32((*[32]byte)(&tgt.Key)) enc.WriteUint8(tgt.MixAttr) + case types.TxOutMultisig: + enc.WriteVariantTag(types.TargetTypeMultisig) + enc.WriteVarint(tgt.MinimumSigs) + enc.WriteVarint(uint64(len(tgt.Keys))) + for i := range tgt.Keys { + enc.WriteBlob32((*[32]byte)(&tgt.Keys[i])) + } + case types.TxOutHTLC: + enc.WriteVariantTag(types.TargetTypeHTLC) + enc.WriteBlob32((*[32]byte)(&tgt.HTLCHash)) + enc.WriteUint8(tgt.Flags) + enc.WriteVarint(tgt.Expiration) + enc.WriteBlob32((*[32]byte)(&tgt.PKRedeem)) + enc.WriteBlob32((*[32]byte)(&tgt.PKRefund)) } case types.TxOutputZarcanum: enc.WriteBlob32((*[32]byte)(&v.StealthAddress)) @@ -363,12 +410,35 @@ func decodeOutputsV2(dec *Decoder) []types.TxOutput { var out types.TxOutputBare out.Amount = dec.ReadVarint() targetTag := dec.ReadVariantTag() - if targetTag == types.TargetTypeToKey { + if dec.Err() != nil { + return vout + } + switch targetTag { + case types.TargetTypeToKey: var tgt types.TxOutToKey dec.ReadBlob32((*[32]byte)(&tgt.Key)) tgt.MixAttr = dec.ReadUint8() out.Target = tgt - } else { + case types.TargetTypeMultisig: + var tgt types.TxOutMultisig + tgt.MinimumSigs = dec.ReadVarint() + keyCount := dec.ReadVarint() + if keyCount > 0 && dec.Err() == nil { + tgt.Keys = make([]types.PublicKey, keyCount) + for j := uint64(0); j < keyCount; j++ { + dec.ReadBlob32((*[32]byte)(&tgt.Keys[j])) + } + } + out.Target = tgt + case types.TargetTypeHTLC: + var tgt types.TxOutHTLC + dec.ReadBlob32((*[32]byte)(&tgt.HTLCHash)) + tgt.Flags = dec.ReadUint8() + tgt.Expiration = dec.ReadVarint() + dec.ReadBlob32((*[32]byte)(&tgt.PKRedeem)) + dec.ReadBlob32((*[32]byte)(&tgt.PKRefund)) + out.Target = tgt + default: dec.err = fmt.Errorf("wire: unsupported target tag 0x%02x", targetTag) return vout } diff --git a/wire/transaction_test.go b/wire/transaction_test.go index 602a34a..2b4f13c 100644 --- a/wire/transaction_test.go +++ b/wire/transaction_test.go @@ -590,3 +590,243 @@ func TestMultisigInputRoundTrip_Good(t *testing.T) { t.Errorf("round-trip mismatch:\n got: %x\n want: %x", rtBuf.Bytes(), buf.Bytes()) } } + +func TestMultisigTargetV1RoundTrip_Good(t *testing.T) { + tx := types.Transaction{ + Version: types.VersionPreHF4, + Vin: []types.TxInput{types.TxInputGenesis{Height: 1}}, + Vout: []types.TxOutput{types.TxOutputBare{ + Amount: 5000, + Target: types.TxOutMultisig{ + MinimumSigs: 2, + Keys: []types.PublicKey{{0x01}, {0x02}, {0x03}}, + }, + }}, + Extra: EncodeVarint(0), + } + + var buf bytes.Buffer + enc := NewEncoder(&buf) + EncodeTransactionPrefix(enc, &tx) + if enc.Err() != nil { + t.Fatalf("encode error: %v", enc.Err()) + } + + dec := NewDecoder(bytes.NewReader(buf.Bytes())) + got := DecodeTransactionPrefix(dec) + if dec.Err() != nil { + t.Fatalf("decode error: %v", dec.Err()) + } + + bare, ok := got.Vout[0].(types.TxOutputBare) + if !ok { + t.Fatalf("vout[0] type: got %T, want TxOutputBare", got.Vout[0]) + } + msig, ok := bare.Target.(types.TxOutMultisig) + if !ok { + t.Fatalf("target type: got %T, want TxOutMultisig", bare.Target) + } + if msig.MinimumSigs != 2 { + t.Errorf("MinimumSigs: got %d, want 2", msig.MinimumSigs) + } + if len(msig.Keys) != 3 { + t.Errorf("Keys count: got %d, want 3", len(msig.Keys)) + } + + // Byte-level round-trip. + var rtBuf bytes.Buffer + enc2 := NewEncoder(&rtBuf) + EncodeTransactionPrefix(enc2, &got) + if enc2.Err() != nil { + t.Fatalf("re-encode error: %v", enc2.Err()) + } + if !bytes.Equal(rtBuf.Bytes(), buf.Bytes()) { + t.Errorf("round-trip mismatch:\n got: %x\n want: %x", rtBuf.Bytes(), buf.Bytes()) + } +} + +func TestHTLCTargetV1RoundTrip_Good(t *testing.T) { + tx := types.Transaction{ + Version: types.VersionPreHF4, + Vin: []types.TxInput{types.TxInputGenesis{Height: 1}}, + Vout: []types.TxOutput{types.TxOutputBare{ + Amount: 7000, + Target: types.TxOutHTLC{ + HTLCHash: types.Hash{0xCC}, + Flags: 1, // RIPEMD160 + Expiration: 20000, + PKRedeem: types.PublicKey{0xDD}, + PKRefund: types.PublicKey{0xEE}, + }, + }}, + Extra: EncodeVarint(0), + } + + var buf bytes.Buffer + enc := NewEncoder(&buf) + EncodeTransactionPrefix(enc, &tx) + if enc.Err() != nil { + t.Fatalf("encode error: %v", enc.Err()) + } + + dec := NewDecoder(bytes.NewReader(buf.Bytes())) + got := DecodeTransactionPrefix(dec) + if dec.Err() != nil { + t.Fatalf("decode error: %v", dec.Err()) + } + + bare, ok := got.Vout[0].(types.TxOutputBare) + if !ok { + t.Fatalf("vout[0] type: got %T, want TxOutputBare", got.Vout[0]) + } + htlc, ok := bare.Target.(types.TxOutHTLC) + if !ok { + t.Fatalf("target type: got %T, want TxOutHTLC", bare.Target) + } + if htlc.HTLCHash[0] != 0xCC { + t.Errorf("HTLCHash[0]: got 0x%02x, want 0xCC", htlc.HTLCHash[0]) + } + if htlc.Flags != 1 { + t.Errorf("Flags: got %d, want 1", htlc.Flags) + } + if htlc.Expiration != 20000 { + t.Errorf("Expiration: got %d, want 20000", htlc.Expiration) + } + if htlc.PKRedeem[0] != 0xDD { + t.Errorf("PKRedeem[0]: got 0x%02x, want 0xDD", htlc.PKRedeem[0]) + } + if htlc.PKRefund[0] != 0xEE { + t.Errorf("PKRefund[0]: got 0x%02x, want 0xEE", htlc.PKRefund[0]) + } + + // Byte-level round-trip. + var rtBuf bytes.Buffer + enc2 := NewEncoder(&rtBuf) + EncodeTransactionPrefix(enc2, &got) + if enc2.Err() != nil { + t.Fatalf("re-encode error: %v", enc2.Err()) + } + if !bytes.Equal(rtBuf.Bytes(), buf.Bytes()) { + t.Errorf("round-trip mismatch:\n got: %x\n want: %x", rtBuf.Bytes(), buf.Bytes()) + } +} + +func TestMultisigTargetV2RoundTrip_Good(t *testing.T) { + tx := types.Transaction{ + Version: types.VersionPostHF4, + Vin: []types.TxInput{types.TxInputGenesis{Height: 1}}, + Vout: []types.TxOutput{types.TxOutputBare{ + Amount: 5000, + Target: types.TxOutMultisig{ + MinimumSigs: 2, + Keys: []types.PublicKey{{0x01}, {0x02}}, + }, + }}, + Extra: EncodeVarint(0), + } + + var buf bytes.Buffer + enc := NewEncoder(&buf) + EncodeTransactionPrefix(enc, &tx) + if enc.Err() != nil { + t.Fatalf("encode error: %v", enc.Err()) + } + + dec := NewDecoder(bytes.NewReader(buf.Bytes())) + got := DecodeTransactionPrefix(dec) + if dec.Err() != nil { + t.Fatalf("decode error: %v", dec.Err()) + } + + bare, ok := got.Vout[0].(types.TxOutputBare) + if !ok { + t.Fatalf("vout[0] type: got %T, want TxOutputBare", got.Vout[0]) + } + msig, ok := bare.Target.(types.TxOutMultisig) + if !ok { + t.Fatalf("target type: got %T, want TxOutMultisig", bare.Target) + } + if msig.MinimumSigs != 2 { + t.Errorf("MinimumSigs: got %d, want 2", msig.MinimumSigs) + } + if len(msig.Keys) != 2 { + t.Errorf("Keys count: got %d, want 2", len(msig.Keys)) + } + + // Byte-level round-trip. + var rtBuf bytes.Buffer + enc2 := NewEncoder(&rtBuf) + EncodeTransactionPrefix(enc2, &got) + if enc2.Err() != nil { + t.Fatalf("re-encode error: %v", enc2.Err()) + } + if !bytes.Equal(rtBuf.Bytes(), buf.Bytes()) { + t.Errorf("round-trip mismatch:\n got: %x\n want: %x", rtBuf.Bytes(), buf.Bytes()) + } +} + +func TestHTLCTargetV2RoundTrip_Good(t *testing.T) { + tx := types.Transaction{ + Version: types.VersionPostHF4, + Vin: []types.TxInput{types.TxInputGenesis{Height: 1}}, + Vout: []types.TxOutput{types.TxOutputBare{ + Amount: 7000, + Target: types.TxOutHTLC{ + HTLCHash: types.Hash{0xCC}, + Flags: 0, // SHA256 + Expiration: 15000, + PKRedeem: types.PublicKey{0xDD}, + PKRefund: types.PublicKey{0xEE}, + }, + }}, + Extra: EncodeVarint(0), + } + + var buf bytes.Buffer + enc := NewEncoder(&buf) + EncodeTransactionPrefix(enc, &tx) + if enc.Err() != nil { + t.Fatalf("encode error: %v", enc.Err()) + } + + dec := NewDecoder(bytes.NewReader(buf.Bytes())) + got := DecodeTransactionPrefix(dec) + if dec.Err() != nil { + t.Fatalf("decode error: %v", dec.Err()) + } + + bare, ok := got.Vout[0].(types.TxOutputBare) + if !ok { + t.Fatalf("vout[0] type: got %T, want TxOutputBare", got.Vout[0]) + } + htlc, ok := bare.Target.(types.TxOutHTLC) + if !ok { + t.Fatalf("target type: got %T, want TxOutHTLC", bare.Target) + } + if htlc.HTLCHash[0] != 0xCC { + t.Errorf("HTLCHash[0]: got 0x%02x, want 0xCC", htlc.HTLCHash[0]) + } + if htlc.Flags != 0 { + t.Errorf("Flags: got %d, want 0", htlc.Flags) + } + if htlc.Expiration != 15000 { + t.Errorf("Expiration: got %d, want 15000", htlc.Expiration) + } + if htlc.PKRedeem[0] != 0xDD { + t.Errorf("PKRedeem[0]: got 0x%02x, want 0xDD", htlc.PKRedeem[0]) + } + if htlc.PKRefund[0] != 0xEE { + t.Errorf("PKRefund[0]: got 0x%02x, want 0xEE", htlc.PKRefund[0]) + } + + // Byte-level round-trip. + var rtBuf bytes.Buffer + enc2 := NewEncoder(&rtBuf) + EncodeTransactionPrefix(enc2, &got) + if enc2.Err() != nil { + t.Fatalf("re-encode error: %v", enc2.Err()) + } + if !bytes.Equal(rtBuf.Bytes(), buf.Bytes()) { + t.Errorf("round-trip mismatch:\n got: %x\n want: %x", rtBuf.Bytes(), buf.Bytes()) + } +}