diff --git a/pkg/runtime/bulk_subscriber.go b/pkg/runtime/bulk_subscriber.go index b70aa3875ef93494ccff1c51bb9ea8eab3da0418..054e8c1b5334e072faa44cb4ac1c6b95425e5993 100644 --- a/pkg/runtime/bulk_subscriber.go +++ b/pkg/runtime/bulk_subscriber.go @@ -382,7 +382,7 @@ func (a *DaprRuntime) publishBulkMessageHTTP(ctx context.Context, bulkSubCallDat if err != nil { bscData.bulkSubDiag.statusWiseDiag[string(pubsub.Retry)] += int64(len(rawMsgEntries)) bscData.bulkSubDiag.elapsed = elapsed - populateBulkSubscribeResponsesWithError(psm, bscData.bulkResponses, err) + populateBulkSubscribeResponsesWithError(psm, bulkResponses, err) return fmt.Errorf("error from app channel while sending pub/sub event to app: %w", err) } defer resp.Close() @@ -402,7 +402,7 @@ func (a *DaprRuntime) publishBulkMessageHTTP(ctx context.Context, bulkSubCallDat if err != nil { bscData.bulkSubDiag.statusWiseDiag[string(pubsub.Retry)] += int64(len(rawMsgEntries)) bscData.bulkSubDiag.elapsed = elapsed - populateBulkSubscribeResponsesWithError(psm, bscData.bulkResponses, err) + populateBulkSubscribeResponsesWithError(psm, bulkResponses, err) return fmt.Errorf("failed unmarshalling app response for bulk subscribe: %w", err) } @@ -474,7 +474,7 @@ func (a *DaprRuntime) publishBulkMessageHTTP(ctx context.Context, bulkSubCallDat log.Errorf("Non-retriable error returned from app while processing bulk pub/sub event. status code returned: %v", statusCode) bscData.bulkSubDiag.statusWiseDiag[string(pubsub.Drop)] += int64(len(rawMsgEntries)) bscData.bulkSubDiag.elapsed = elapsed - populateBulkSubscribeResponsesWithError(psm, bscData.bulkResponses, nil) + populateBulkSubscribeResponsesWithError(psm, bulkResponses, nil) return nil } @@ -484,7 +484,7 @@ func (a *DaprRuntime) publishBulkMessageHTTP(ctx context.Context, bulkSubCallDat log.Warn(retriableErrorStr) bscData.bulkSubDiag.statusWiseDiag[string(pubsub.Retry)] += int64(len(rawMsgEntries)) bscData.bulkSubDiag.elapsed = elapsed - populateBulkSubscribeResponsesWithError(psm, bscData.bulkResponses, retriableError) + populateBulkSubscribeResponsesWithError(psm, bulkResponses, retriableError) return retriableError } @@ -628,7 +628,7 @@ func (a *DaprRuntime) publishBulkMessageGRPC(ctx context.Context, bulkSubCallDat log.Warnf("non-retriable error returned from app while processing bulk pub/sub event: %s", err) bscData.bulkSubDiag.statusWiseDiag[string(pubsub.Drop)] += int64(len(psm.pubSubMessages)) bscData.bulkSubDiag.elapsed = elapsed - populateBulkSubscribeResponsesWithError(psm, bscData.bulkResponses, nil) + populateBulkSubscribeResponsesWithError(psm, bulkResponses, nil) return nil } @@ -636,9 +636,9 @@ func (a *DaprRuntime) publishBulkMessageGRPC(ctx context.Context, bulkSubCallDat log.Debug(err) bscData.bulkSubDiag.statusWiseDiag[string(pubsub.Retry)] += int64(len(psm.pubSubMessages)) bscData.bulkSubDiag.elapsed = elapsed - populateBulkSubscribeResponsesWithError(psm, bscData.bulkResponses, err) + populateBulkSubscribeResponsesWithError(psm, bulkResponses, err) // on error from application, return error for redelivery of event - return nil + return err } hasAnyError := false diff --git a/pkg/runtime/bulk_subscriber_test.go b/pkg/runtime/bulk_subscriber_test.go index 25e7faf4baf7d42bb9526bf57730d7184239c31a..a2226547571cb4ca1cfe8bc538017538ee998418 100644 --- a/pkg/runtime/bulk_subscriber_test.go +++ b/pkg/runtime/bulk_subscriber_test.go @@ -3,6 +3,7 @@ package runtime import ( "encoding/json" + "errors" "fmt" "net" "strings" @@ -14,6 +15,8 @@ import ( "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" "google.golang.org/grpc" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" metaV1 "k8s.io/apimachinery/pkg/apis/meta/v1" "github.com/dapr/components-contrib/pubsub" @@ -148,7 +151,7 @@ func TestBulkSubscribe(t *testing.T) { Topic: "topic0", Data: []byte(`{"orderId":"1"}`), }) - assert.NoError(t, err) + assert.Error(t, err) pubSub, ok := rt.compStore.GetPubSub(testBulkSubscribePubsub) require.True(t, ok) pubsubIns := pubSub.Component.(*mockSubscribePubSub) @@ -196,7 +199,7 @@ func TestBulkSubscribe(t *testing.T) { Topic: "topic0", Data: []byte(order), }) - assert.Nil(t, err) + assert.Error(t, err) pubSub, ok := rt.compStore.GetPubSub(testBulkSubscribePubsub) require.True(t, ok) pubsubIns := pubSub.Component.(*mockSubscribePubSub) @@ -210,9 +213,10 @@ func TestBulkSubscribe(t *testing.T) { t.Run("bulk Subscribe multiple Messages at once for cloud events", func(t *testing.T) { rt := NewTestDaprRuntime(modes.StandaloneMode) defer stopRuntime(t, rt) + ms := &mockSubscribePubSub{} rt.pubSubRegistry.RegisterComponent( func(_ logger.Logger) pubsub.PubSub { - return &mockSubscribePubSub{} + return ms }, "mockPubSub", ) @@ -233,19 +237,24 @@ func TestBulkSubscribe(t *testing.T) { mockAppChannel.Init() rt.appChannel = mockAppChannel mockAppChannel.On("InvokeMethod", mock.MatchedBy(matchContextInterface), matchDaprRequestMethod("dapr/subscribe")).Return(fakeResp, nil) - mockAppChannel.On("InvokeMethod", mock.MatchedBy(matchContextInterface), mock.Anything).Return(fakeResp, nil) + fakeResp1 := invokev1.NewInvokeMethodResponse(200, "OK", nil) + defer fakeResp1.Close() + mockAppChannel.On("InvokeMethod", mock.MatchedBy(matchContextInterface), mock.Anything).Return(fakeResp1, nil) require.NoError(t, rt.initPubSub(pubsubComponent)) rt.startSubscriptions() msgArr := getBulkMessageEntries(2) - _, err := rt.BulkPublish(&pubsub.BulkPublishRequest{ + rt.BulkPublish(&pubsub.BulkPublishRequest{ PubsubName: testBulkSubscribePubsub, Topic: "topic0", Entries: msgArr, }) - assert.Nil(t, err) + + assert.Equal(t, 2, len(ms.GetBulkResponse().Statuses)) + assert.Error(t, ms.GetBulkResponse().Error) + assert.Nil(t, assertItemExistsOnce(ms.GetBulkResponse().Statuses, "1111111a", "2222222b")) pubSub, ok := rt.compStore.GetPubSub(testBulkSubscribePubsub) require.True(t, ok) @@ -256,6 +265,83 @@ func TestBulkSubscribe(t *testing.T) { mockAppChannel.AssertNumberOfCalls(t, "InvokeMethod", 2) assert.Contains(t, string(reqs["orders"]), `"event":`+order1) assert.Contains(t, string(reqs["orders"]), `"event":`+order2) + + fakeResp2 := invokev1.NewInvokeMethodResponse(404, "OK", nil) + defer fakeResp2.Close() + mockAppChannel1 := new(channelt.MockAppChannel) + mockAppChannel1.Init() + rt.appChannel = mockAppChannel1 + mockAppChannel1.On("InvokeMethod", mock.MatchedBy(matchContextInterface), mock.Anything).Return(fakeResp2, nil) + + msgArr = getBulkMessageEntries(3) + + rt.BulkPublish(&pubsub.BulkPublishRequest{ + PubsubName: testBulkSubscribePubsub, + Topic: "topic0", + Entries: msgArr, + }) + + assert.Equal(t, 3, len(ms.GetBulkResponse().Statuses)) + assert.Nil(t, ms.GetBulkResponse().Error) + assert.Nil(t, assertItemExistsOnce(ms.GetBulkResponse().Statuses, "1111111a", "2222222b", "333333c")) + + assert.Equal(t, 2, pubsubIns.bulkPubCount["topic0"]) + assert.True(t, pubsubIns.isBulkSubscribe) + reqs = mockAppChannel1.GetInvokedRequest() + mockAppChannel1.AssertNumberOfCalls(t, "InvokeMethod", 1) + assert.Contains(t, string(reqs["orders"]), `"event":`+order1) + assert.Contains(t, string(reqs["orders"]), `"event":`+order2) + assert.Contains(t, string(reqs["orders"]), `"event":`+order3) + + fakeResp3 := invokev1.NewInvokeMethodResponse(400, "OK", nil) + defer fakeResp3.Close() + mockAppChannel2 := new(channelt.MockAppChannel) + mockAppChannel2.Init() + rt.appChannel = mockAppChannel2 + mockAppChannel2.On("InvokeMethod", mock.MatchedBy(matchContextInterface), mock.Anything).Return(fakeResp3, nil) + + msgArr = getBulkMessageEntries(4) + + rt.BulkPublish(&pubsub.BulkPublishRequest{ + PubsubName: testBulkSubscribePubsub, + Topic: "topic0", + Entries: msgArr, + }) + + assert.Equal(t, 4, len(ms.GetBulkResponse().Statuses)) + assert.Error(t, ms.GetBulkResponse().Error) + assert.Nil(t, assertItemExistsOnce(ms.GetBulkResponse().Statuses, "1111111a", "2222222b", "333333c", "4444444d")) + + assert.Equal(t, 3, pubsubIns.bulkPubCount["topic0"]) + assert.True(t, pubsubIns.isBulkSubscribe) + reqs = mockAppChannel2.GetInvokedRequest() + mockAppChannel2.AssertNumberOfCalls(t, "InvokeMethod", 1) + assert.Contains(t, string(reqs["orders"]), `"event":`+order1) + assert.Contains(t, string(reqs["orders"]), `"event":`+order2) + assert.Contains(t, string(reqs["orders"]), `"event":`+order3) + assert.Contains(t, string(reqs["orders"]), `"event":`+order4) + + mockAppChannel3 := new(channelt.MockAppChannel) + mockAppChannel3.Init() + rt.appChannel = mockAppChannel3 + mockAppChannel3.On("InvokeMethod", mock.MatchedBy(matchContextInterface), mock.Anything).Return(nil, errors.New("Mock error")) + msgArr = getBulkMessageEntries(1) + + rt.BulkPublish(&pubsub.BulkPublishRequest{ + PubsubName: testBulkSubscribePubsub, + Topic: "topic0", + Entries: msgArr, + }) + + assert.Equal(t, 1, len(ms.GetBulkResponse().Statuses)) + assert.Error(t, ms.GetBulkResponse().Error) + assert.Nil(t, assertItemExistsOnce(ms.GetBulkResponse().Statuses, "1111111a")) + + assert.Equal(t, 4, pubsubIns.bulkPubCount["topic0"]) + assert.True(t, pubsubIns.isBulkSubscribe) + reqs = mockAppChannel3.GetInvokedRequest() + mockAppChannel3.AssertNumberOfCalls(t, "InvokeMethod", 1) + assert.Contains(t, string(reqs["orders"]), `"event":`+order1) }) t.Run("bulk Subscribe events on different paths", func(t *testing.T) { @@ -715,10 +801,11 @@ func TestBulkSubscribeGRPC(t *testing.T) { port, _ := freeport.GetFreePort() rt := NewTestDaprRuntimeWithProtocol(modes.StandaloneMode, string(GRPCProtocol), port) defer stopRuntime(t, rt) + ms := &mockSubscribePubSub{} rt.pubSubRegistry.RegisterComponent( func(_ logger.Logger) pubsub.PubSub { - return &mockSubscribePubSub{} + return ms }, "mockPubSub", ) @@ -780,6 +867,10 @@ func TestBulkSubscribeGRPC(t *testing.T) { Topic: "topic0", Entries: msgArr, }) + assert.Equal(t, 2, len(ms.GetBulkResponse().Statuses)) + assert.Nil(t, ms.GetBulkResponse().Error) + assert.Nil(t, assertItemExistsOnce(ms.GetBulkResponse().Statuses, "1111111a", "2222222b")) + assert.Nil(t, err) pubSub, ok := rt.compStore.GetPubSub(testBulkSubscribePubsub) require.True(t, ok) @@ -794,6 +885,27 @@ func TestBulkSubscribeGRPC(t *testing.T) { assert.Contains(t, string(mockServer.RequestsReceived["orders"].GetEntries()[0].GetBytes()), `{"orderId":"1"}`) assert.Contains(t, string(mockServer.RequestsReceived["orders"].GetEntries()[1].GetBytes()), `{"orderId":"2"}`) assert.True(t, verifyBulkSubscribeResponses(expectedResponse, pubsubIns.bulkReponse.Statuses)) + + mockServer.BulkResponsePerPath = nil + mockServer.Error = status.Error(codes.Unimplemented, "method not implemented") + rt.BulkPublish(&pubsub.BulkPublishRequest{ + PubsubName: testBulkSubscribePubsub, + Topic: "topic0", + Entries: msgArr, + }) + assert.Equal(t, 2, len(ms.GetBulkResponse().Statuses)) + assert.Nil(t, ms.GetBulkResponse().Error) + assert.Nil(t, assertItemExistsOnce(ms.GetBulkResponse().Statuses, "1111111a", "2222222b")) + + mockServer.Error = status.Error(codes.Unknown, "unknown error") + rt.BulkPublish(&pubsub.BulkPublishRequest{ + PubsubName: testBulkSubscribePubsub, + Topic: "topic0", + Entries: msgArr, + }) + assert.Equal(t, 2, len(ms.GetBulkResponse().Statuses)) + assert.Error(t, ms.GetBulkResponse().Error) + assert.Nil(t, assertItemExistsOnce(ms.GetBulkResponse().Statuses, "1111111a", "2222222b")) }) t.Run("GRPC - bulk Subscribe cloud event Message on different paths and verify response", func(t *testing.T) { @@ -1322,3 +1434,19 @@ func verifyBulkSubscribeRequest(expectedData []string, expectedExtension Expecte } return true } + +func assertItemExistsOnce(collection []pubsub.BulkSubscribeResponseEntry, items ...string) error { + count := 0 + for _, item := range items { + for _, c := range collection { + if c.EntryId == item { + count++ + } + } + if count != 1 { + return fmt.Errorf("item %s not found or found more than once", item) + } + count = 0 + } + return nil +} diff --git a/pkg/runtime/runtime_test.go b/pkg/runtime/runtime_test.go index 1e6496654bfdb6dd84ff191b5dee852a38fe0c8e..9684986d269a845f5fff5fbf2c5b3d506384bcaa 100644 --- a/pkg/runtime/runtime_test.go +++ b/pkg/runtime/runtime_test.go @@ -3984,6 +3984,7 @@ func (m *mockSubscribePubSub) Init(ctx context.Context, metadata pubsub.Metadata // Publish is a mock publish method. Immediately trigger handler if topic is subscribed. func (m *mockSubscribePubSub) Publish(ctx context.Context, req *pubsub.PublishRequest) error { m.pubCount[req.Topic]++ + var err error if handler, ok := m.handlers[req.Topic]; ok { pubsubMsg := &pubsub.NewMessage{ Data: req.Data, @@ -4001,9 +4002,9 @@ func (m *mockSubscribePubSub) Publish(ctx context.Context, req *pubsub.PublishRe Entries: msgArr, Topic: req.Topic, } - bulkHandler(context.Background(), nbm) + _, err = bulkHandler(context.Background(), nbm) } - return nil + return err } // BulkPublish is a mock publish method. Immediately call the handler for each event in request if topic is subscribed. @@ -4057,6 +4058,10 @@ func (m *mockSubscribePubSub) GetComponentMetadata() map[string]string { return map[string]string{} } +func (m *mockSubscribePubSub) GetBulkResponse() pubsub.BulkSubscribeResponse { + return m.bulkReponse +} + func TestPubSubDeadLetter(t *testing.T) { testDeadLetterPubsub := "failPubsub" pubsubComponent := componentsV1alpha1.Component{