From a28393388b09c720382881432e80ac33efb3480c Mon Sep 17 00:00:00 2001 From: saji Date: Thu, 7 Mar 2024 23:02:46 -0600 Subject: [PATCH] tests! --- broker.go | 5 +- broker_test.go | 124 ++++++++++++++++++++++++++++++++++ db_test.go | 7 +- http.go | 2 - http_test.go | 177 +++++++++++++++++++++++++++++++++++++++++++++++++ 5 files changed, 310 insertions(+), 5 deletions(-) create mode 100644 broker_test.go create mode 100644 http_test.go diff --git a/broker.go b/broker.go index 285f7f1..fcb901d 100644 --- a/broker.go +++ b/broker.go @@ -52,7 +52,10 @@ func (b *Broker) Unsubscribe(name string) { b.lock.Lock() defer b.lock.Unlock() b.logger.Debug("unsubscribe", "name", name) - delete(b.subs, name) + if _, ok := b.subs[name]; ok { + close(b.subs[name]) + delete(b.subs, name) + } } // Publish sends a bus event to all subscribers. It includes a sender diff --git a/broker_test.go b/broker_test.go new file mode 100644 index 0000000..2df8b66 --- /dev/null +++ b/broker_test.go @@ -0,0 +1,124 @@ +package gotelem + +import ( + "log/slog" + "os" + "reflect" + "sync" + "testing" + "time" + + "github.com/kschamplin/gotelem/skylab" +) + +func makeEvent() skylab.BusEvent { + var pkt skylab.Packet = &skylab.BmsMeasurement{ + BatteryVoltage: 12000, + AuxVoltage: 24000, + Current: 1.23, + } + return skylab.BusEvent{ + Timestamp: time.Now(), + Name: pkt.String(), + Data: pkt, + } +} + +func TestBroker(t *testing.T) { + t.Parallel() + + t.Run("test send", func(t *testing.T) { + flog := slog.New(slog.NewTextHandler(os.Stderr, nil)) + broker := NewBroker(10, flog) + + sub, err := broker.Subscribe("testSub") + if err != nil { + t.Fatalf("error subscribing: %v", err) + } + testEvent := makeEvent() + + go func() { + time.Sleep(time.Millisecond * 1) + broker.Publish("other", testEvent) + }() + + var recvEvent skylab.BusEvent + select { + case recvEvent = <-sub: + if !reflect.DeepEqual(recvEvent.Data, testEvent.Data) { + t.Fatalf("mismatched data, want %v got %v", testEvent.Data, recvEvent.Data) + } + if !testEvent.Timestamp.Equal(recvEvent.Timestamp) { + t.Fatalf("mismatched timestamp, want %v got %v", testEvent.Timestamp, recvEvent.Timestamp) + } + case <-time.After(1 * time.Second): + t.Fatalf("timeout waiting for packet") + } + + }) + t.Run("multiple broadcast", func(t *testing.T) { + flog := slog.New(slog.NewTextHandler(os.Stderr, nil)) + broker := NewBroker(10, flog) + testEvent := makeEvent() + wg := sync.WaitGroup{} + + clientFn := func(name string) { + sub, err := broker.Subscribe(name) + if err != nil { + t.Log(err) + return + } + <-sub + wg.Done() + } + + wg.Add(2) + go clientFn("client1") + go clientFn("client2") + + // yes this is stupid. otherwise we race. + time.Sleep(10 * time.Millisecond) + + broker.Publish("sender", testEvent) + + done := make(chan bool) + go func() { + wg.Wait() + done <- true + }() + select { + case <-done: + + case <-time.After(1 * time.Second): + t.Fatal("timeout waiting for clients") + } + }) + + t.Run("name collision", func(t *testing.T) { + flog := slog.New(slog.NewTextHandler(os.Stderr, nil)) + broker := NewBroker(10, flog) + _, err := broker.Subscribe("collide") + if err != nil { + t.Fatal(err) + } + _, err = broker.Subscribe("collide") + if err == nil { + t.Fatal("expected error, got nil") + } + + }) + + t.Run("unsubscribe", func(t *testing.T) { + flog := slog.New(slog.NewTextHandler(os.Stderr, nil)) + broker := NewBroker(10, flog) + ch, err := broker.Subscribe("test") + if err != nil { + t.Fatal(err) + } + broker.Unsubscribe("test") + _, ok := <-ch + if ok { + t.Fatal("expected dead channel, but channel returned result") + } + }) +} diff --git a/db_test.go b/db_test.go index 52c31ae..0e5f820 100644 --- a/db_test.go +++ b/db_test.go @@ -68,6 +68,10 @@ func MakeMockDatabase(name string) *TelemDb { if err != nil { panic(err) } + return tdb +} + +func SeedMockDatabase(tdb *TelemDb) { // seed the database now. scanner := bufio.NewScanner(strings.NewReader(exampleData)) @@ -83,8 +87,6 @@ func MakeMockDatabase(name string) *TelemDb { panic(err) } } - - return tdb } func TestTelemDb(t *testing.T) { @@ -139,6 +141,7 @@ func TestTelemDb(t *testing.T) { t.Run("test getting packets", func(t *testing.T) { tdb := MakeMockDatabase(t.Name()) + SeedMockDatabase(tdb) ctx := context.Background() f := BusEventFilter{} diff --git a/http.go b/http.go index 6b1132a..20aaef4 100644 --- a/http.go +++ b/http.go @@ -223,11 +223,9 @@ func apiV1GetPackets(tdb *TelemDb) http.HandlerFunc { lim, err := extractLimitModifier(r) if err != nil { http.Error(w, err.Error(), http.StatusInternalServerError) - fmt.Print(lim) return } - // TODO: is the following check needed? var res []skylab.BusEvent res, err = tdb.GetPackets(r.Context(), *bef, lim) if err != nil { diff --git a/http_test.go b/http_test.go new file mode 100644 index 0000000..ee96512 --- /dev/null +++ b/http_test.go @@ -0,0 +1,177 @@ +package gotelem + +import ( + "fmt" + "net/http" + "net/http/httptest" + "net/url" + "reflect" + "testing" + "time" +) + +func Test_extractBusEventFilter(t *testing.T) { + makeReq := func(path string) *http.Request { + return httptest.NewRequest(http.MethodGet, path, nil) + } + tests := []struct { + name string + req *http.Request + want *BusEventFilter + wantErr bool + }{ + { + name: "test no extractions", + req: makeReq("http://localhost/"), + want: &BusEventFilter{}, + wantErr: false, + }, + { + name: "test single name extract", + req: makeReq("http://localhost/?name=hi"), + want: &BusEventFilter{ + Names: []string{"hi"}, + }, + wantErr: false, + }, + { + name: "test multi name extract", + req: makeReq("http://localhost/?name=hi1&name=hi2"), + want: &BusEventFilter{ + Names: []string{"hi1", "hi2"}, + }, + wantErr: false, + }, + { + name: "test start time valid extract", + req: makeReq(fmt.Sprintf("http://localhost/?start=%s", url.QueryEscape(time.Unix(160000000, 0).Format(time.RFC3339)))), + want: &BusEventFilter{ + StartTime: time.Unix(160000000, 0), + }, + wantErr: false, + }, + // { + // name: "test start time invalid extract", + // req: makeReq(fmt.Sprintf("http://localhost/?start=%s", url.QueryEscape("ajlaskdj"))), + // wantErr: true, + // }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Logf("Testing URL %s", tt.req.URL.String()) + got, err := extractBusEventFilter(tt.req) + if (err != nil) != tt.wantErr { + t.Errorf("extractBusEventFilter() error = %v, wantErr %v", err, tt.wantErr) + return + } + // we have to manually compare fields because timestamps can't be deeply compared. + if !reflect.DeepEqual(got.Names, tt.want.Names) { + t.Errorf("extractBusEventFilter() Names bad = %v, want %v", got.Names, tt.want.Names) + } + if !reflect.DeepEqual(got.Indexes, tt.want.Indexes) { + t.Errorf("extractBusEventFilter() Indexes bad = %v, want %v", got.Indexes, tt.want.Indexes) + } + if !got.StartTime.Equal(tt.want.StartTime) { + t.Errorf("extractBusEventFilter() StartTime mismatch = %v, want %v", got.StartTime, tt.want.StartTime) + } + if !got.EndTime.Equal(tt.want.EndTime) { + t.Errorf("extractBusEventFilter() EndTime mismatch = %v, want %v", got.EndTime, tt.want.EndTime) + } + }) + } +} + +func Test_extractLimitModifier(t *testing.T) { + makeReq := func(path string) *http.Request { + return httptest.NewRequest(http.MethodGet, path, nil) + } + tests := []struct { + name string + req *http.Request + want *LimitOffsetModifier + wantErr bool + }{ + { + name: "test no limit/offset", + req: makeReq("http://localhost/"), + want: nil, + wantErr: false, + }, + { + name: "test limit, no offset", + req: makeReq("http://localhost/?limit=10"), + want: &LimitOffsetModifier{Limit: 10}, + wantErr: false, + }, + { + name: "test limit and offset", + req: makeReq("http://localhost/?limit=100&offset=200"), + want: &LimitOffsetModifier{Limit: 100, Offset: 200}, + wantErr: false, + }, + { + name: "test only offset", + req: makeReq("http://localhost/?&offset=200"), + want: nil, + wantErr: false, + }, + { + name: "test bad limit", + req: makeReq("http://localhost/?limit=aaaa"), + want: nil, + wantErr: true, + }, + { + name: "test good limit, bad offset", + req: makeReq("http://localhost/?limit=10&offset=jjjj"), + want: nil, + wantErr: true, + }, + + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := extractLimitModifier(tt.req) + if (err != nil) != tt.wantErr { + t.Errorf("extractLimitModifier() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("extractLimitModifier() = %v, want %v", got, tt.want) + } + }) + } +} + +func Test_ApiV1GetPackets(t *testing.T) { + tdb := MakeMockDatabase(t.Name()) + SeedMockDatabase(tdb) + handler := apiV1GetPackets(tdb) + + tests := []struct{ + name string + req *http.Request + statusCode int + }{ + { + name: "stationary test", + req: httptest.NewRequest(http.MethodGet, "http://localhost/", nil), + statusCode: http.StatusOK, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // construct the recorder + w := httptest.NewRecorder() + handler(w, tt.req) + + resp := w.Result() + + if tt.statusCode != resp.StatusCode { + t.Errorf("incorrect status code: expected %d got %d", tt.statusCode, resp.StatusCode) + } + + }) + } +}