diff --git a/app/router/command/command_test.go b/app/router/command/command_test.go index 2dce1838..a7c97f92 100644 --- a/app/router/command/command_test.go +++ b/app/router/command/command_test.go @@ -45,7 +45,6 @@ func TestServiceSubscribeRoutingStats(t *testing.T) { {SourceIPs: [][]byte{{127, 0, 0, 1}}, Attributes: map[string]string{"attr": "value"}, OutboundTag: "out"}, } errCh := make(chan error) - nextPub := make(chan struct{}) // Server goroutine go func() { @@ -77,13 +76,6 @@ func TestServiceSubscribeRoutingStats(t *testing.T) { if err := publishTestCases(); err != nil { errCh <- err } - - // Wait for next round of publishing - <-nextPub - - if err := publishTestCases(); err != nil { - errCh <- err - } }() // Client goroutine @@ -145,6 +137,92 @@ func TestServiceSubscribeRoutingStats(t *testing.T) { return nil } + if err := testRetrievingAllFields(); err != nil { + errCh <- err + } + errCh <- nil // Client passed all tests successfully + }() + + // Wait for goroutines to complete + select { + case <-time.After(2 * time.Second): + t.Fatal("Test timeout after 2s") + case err := <-errCh: + if err != nil { + t.Fatal(err) + } + } +} + +func TestServiceSubscribeSubsetOfFields(t *testing.T) { + c := stats.NewChannel(&stats.ChannelConfig{ + SubscriberLimit: 1, + BufferSize: 0, + Blocking: true, + }) + common.Must(c.Start()) + defer c.Close() + + lis := bufconn.Listen(1024 * 1024) + bufDialer := func(context.Context, string) (net.Conn, error) { + return lis.Dial() + } + + testCases := []*RoutingContext{ + {InboundTag: "in", OutboundTag: "out"}, + {TargetIPs: [][]byte{{1, 2, 3, 4}}, TargetPort: 8080, OutboundTag: "out"}, + {TargetDomain: "example.com", TargetPort: 443, OutboundTag: "out"}, + {SourcePort: 9999, TargetPort: 9999, OutboundTag: "out"}, + {Network: net.Network_UDP, OutboundGroupTags: []string{"outergroup", "innergroup"}, OutboundTag: "out"}, + {Protocol: "bittorrent", OutboundTag: "blocked"}, + {User: "example@example.com", OutboundTag: "out"}, + {SourceIPs: [][]byte{{127, 0, 0, 1}}, Attributes: map[string]string{"attr": "value"}, OutboundTag: "out"}, + } + errCh := make(chan error) + + // Server goroutine + go func() { + server := grpc.NewServer() + RegisterRoutingServiceServer(server, NewRoutingServer(nil, c)) + errCh <- server.Serve(lis) + }() + + // Publisher goroutine + go func() { + publishTestCases := func() error { + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + for { // Wait until there's one subscriber in routing stats channel + if len(c.Subscribers()) > 0 { + break + } + if ctx.Err() != nil { + return ctx.Err() + } + } + for _, tc := range testCases { + c.Publish(context.Background(), AsRoutingRoute(tc)) + time.Sleep(time.Millisecond) + } + return nil + } + + if err := publishTestCases(); err != nil { + errCh <- err + } + }() + + // Client goroutine + go func() { + defer lis.Close() + conn, err := grpc.DialContext(context.Background(), "bufnet", grpc.WithContextDialer(bufDialer), grpc.WithInsecure()) + if err != nil { + errCh <- err + return + } + defer conn.Close() + client := NewRoutingServiceClient(conn) + // Test retrieving only a subset of fields testRetrievingSubsetOfFields := func() error { streamCtx, streamClose := context.WithCancel(context.Background()) @@ -156,9 +234,6 @@ func TestServiceSubscribeRoutingStats(t *testing.T) { return err } - // Send nextPub signal to start next round of publishing - close(nextPub) - for _, tc := range testCases { msg, err := stream.Recv() if err != nil { @@ -180,10 +255,6 @@ func TestServiceSubscribeRoutingStats(t *testing.T) { return nil } - - if err := testRetrievingAllFields(); err != nil { - errCh <- err - } if err := testRetrievingSubsetOfFields(); err != nil { errCh <- err }