diff --git a/commands_test.go b/commands_test.go index cb01f7c..5927317 100644 --- a/commands_test.go +++ b/commands_test.go @@ -7,6 +7,8 @@ package blockchain import ( "context" + "errors" + "net" "testing" "time" @@ -133,3 +135,72 @@ func TestRunChainSyncOnce_Bad_RespectsCancelledContext(t *testing.T) { require.Error(t, err) assert.Less(t, time.Since(start), 2*time.Second) } + +func TestRunChainSyncOnce_Bad_ReportsPeerIDReadError(t *testing.T) { + listener, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + defer listener.Close() + + done := make(chan struct{}) + defer close(done) + go func() { + conn, err := listener.Accept() + if err != nil { + return + } + defer conn.Close() + <-done + }() + + oldReadPeerID := readPeerID + readPeerID = func([]byte) (int, error) { + return 0, errors.New("peer id failed") + } + defer func() { + readPeerID = oldReadPeerID + }() + + err = runChainSyncOnce(context.Background(), nil, &config.ChainConfig{}, chain.SyncOptions{}, listener.Addr().String()) + require.Error(t, err) + assert.ErrorContains(t, err, "read peer ID") + assert.ErrorContains(t, err, "peer id failed") +} + +func TestRunChainSyncOnce_Bad_ReportsHeightLookupError(t *testing.T) { + listener, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + defer listener.Close() + + done := make(chan struct{}) + defer close(done) + go func() { + conn, err := listener.Accept() + if err != nil { + return + } + defer conn.Close() + <-done + }() + + oldReadPeerID := readPeerID + readPeerID = func(buf []byte) (int, error) { + copy(buf, []byte{1, 2, 3, 4, 5, 6, 7, 8}) + return len(buf), nil + } + defer func() { + readPeerID = oldReadPeerID + }() + + oldChainHeight := chainHeight + chainHeight = func(*chain.Chain) (uint64, error) { + return 0, errors.New("height failed") + } + defer func() { + chainHeight = oldChainHeight + }() + + err = runChainSyncOnce(context.Background(), nil, &config.ChainConfig{}, chain.SyncOptions{}, listener.Addr().String()) + require.Error(t, err) + assert.ErrorContains(t, err, "read local height") + assert.ErrorContains(t, err, "height failed") +} diff --git a/sync_loop.go b/sync_loop.go index abca123..136e3d0 100644 --- a/sync_loop.go +++ b/sync_loop.go @@ -22,6 +22,11 @@ import ( levin "dappco.re/go/core/p2p/node/levin" ) +var readPeerID = rand.Read +var chainHeight = func(blockchain *chain.Chain) (uint64, error) { + return blockchain.Height() +} + func runChainSyncLoop(ctx context.Context, blockchain *chain.Chain, chainConfig *config.ChainConfig, hardForks []config.HardFork, seed string) { opts := chain.SyncOptions{ VerifySignatures: false, @@ -75,10 +80,15 @@ func runChainSyncOnce(ctx context.Context, blockchain *chain.Chain, chainConfig levinConnection := levin.NewConnection(tcpConn) var peerIDBuf [8]byte - rand.Read(peerIDBuf[:]) + if _, err := readPeerID(peerIDBuf[:]); err != nil { + return coreerr.E("runChainSyncOnce", "read peer ID", err) + } peerID := binary.LittleEndian.Uint64(peerIDBuf[:]) - localHeight, _ := blockchain.Height() + localHeight, err := chainHeight(blockchain) + if err != nil { + return coreerr.E("runChainSyncOnce", "read local height", err) + } handshakeReq := p2p.HandshakeRequest{ NodeData: p2p.NodeData{