diff --git a/pkg/apiserver/auditing/types.go b/pkg/apiserver/auditing/types.go index 775c84c8277b1a70f6e86ca2e540a197a6072aed..353072df0ef6b0701ae8745c53eb8be2ec9d9803 100644 --- a/pkg/apiserver/auditing/types.go +++ b/pkg/apiserver/auditing/types.go @@ -1,8 +1,10 @@ package auditing import ( + "bufio" "bytes" "encoding/json" + "fmt" "github.com/google/uuid" "io/ioutil" "k8s.io/apimachinery/pkg/apis/meta/v1" @@ -14,6 +16,7 @@ import ( "kubesphere.io/kubesphere/pkg/apiserver/request" "kubesphere.io/kubesphere/pkg/client/listers/auditing/v1alpha1" "kubesphere.io/kubesphere/pkg/utils/iputil" + "net" "net/http" "time" ) @@ -186,7 +189,6 @@ type ResponseCapture struct { wroteHeader bool status int body *bytes.Buffer - StopCh chan interface{} } func NewResponseCapture(w http.ResponseWriter) *ResponseCapture { @@ -194,7 +196,6 @@ func NewResponseCapture(w http.ResponseWriter) *ResponseCapture { ResponseWriter: w, wroteHeader: false, body: new(bytes.Buffer), - StopCh: make(chan interface{}, 1), } } @@ -204,10 +205,6 @@ func (c *ResponseCapture) Header() http.Header { func (c *ResponseCapture) Write(data []byte) (int, error) { - defer func() { - c.StopCh <- struct{}{} - }() - c.WriteHeader(http.StatusOK) c.body.Write(data) return c.ResponseWriter.Write(data) @@ -216,6 +213,7 @@ func (c *ResponseCapture) Write(data []byte) (int, error) { func (c *ResponseCapture) WriteHeader(statusCode int) { if !c.wroteHeader { c.status = statusCode + c.ResponseWriter.WriteHeader(statusCode) c.wroteHeader = true } } @@ -227,3 +225,14 @@ func (c *ResponseCapture) Bytes() []byte { func (c *ResponseCapture) StatusCode() int { return c.status } + +// Hijack implements the http.Hijacker interface. This expands +// the Response to fulfill http.Hijacker if the underlying +// http.ResponseWriter supports it. +func (c *ResponseCapture) Hijack() (net.Conn, *bufio.ReadWriter, error) { + hijacker, ok := c.ResponseWriter.(http.Hijacker) + if !ok { + return nil, nil, fmt.Errorf("ResponseWriter doesn't support Hijacker interface") + } + return hijacker.Hijack() +} diff --git a/pkg/apiserver/auditing/types_test.go b/pkg/apiserver/auditing/types_test.go index 2e5367bbe0b1e71735aae1ca560c55f9b4dcd24d..85ccf8f37769eaaba6df3e73d4153c6ae46077ae 100644 --- a/pkg/apiserver/auditing/types_test.go +++ b/pkg/apiserver/auditing/types_test.go @@ -16,6 +16,7 @@ import ( ksinformers "kubesphere.io/kubesphere/pkg/client/informers/externalversions" "kubesphere.io/kubesphere/pkg/utils/iputil" "net/http" + "net/http/httptest" "net/url" "testing" "time" @@ -248,7 +249,7 @@ func TestAuditing_LogResponseObject(t *testing.T) { e := a.LogRequestObject(req, info) - resp := &ResponseCapture{} + resp := NewResponseCapture(httptest.NewRecorder()) resp.WriteHeader(200) a.LogResponseObject(e, resp, info) @@ -295,3 +296,29 @@ func TestAuditing_LogResponseObject(t *testing.T) { assert.EqualValues(t, string(expectedBs), string(bs)) } + +func TestResponseCapture_WriteHeader(t *testing.T) { + record := httptest.NewRecorder() + resp := NewResponseCapture(record) + + resp.WriteHeader(404) + + assert.EqualValues(t, 404, resp.StatusCode()) + assert.EqualValues(t, 404, record.Code) +} + +func TestResponseCapture_Write(t *testing.T) { + + record := httptest.NewRecorder() + resp := NewResponseCapture(record) + + body := []byte("123") + + _, err := resp.Write(body) + if err != nil { + panic(err) + } + + assert.EqualValues(t, body, resp.Bytes()) + assert.EqualValues(t, body, record.Body.Bytes()) +}