2016-11-16 77 views
-1

我可以在下面的示例代碼中看到兩個主要問題,但我不知道如何正確解決它們。超時處理程序中的競爭條件

如果超時處理程序未通過errCh獲取下一個處理程序已完成或發生錯誤的信號,它將回復請求的「408請求超時」。

這裏的問題是ResponseWriter不安全,可能被多個goroutine使用。超時處理程序在執行下一個處理程序時啓動一個新的goroutine。

問題:

  1. 如何防止下一個處理從編寫到ResponseWriter當CTX的完成通道超時在超時處理程序。

  2. 如何防止超時處理程序在下一個處理程序正在寫入ResponseWriter時回覆408狀態碼,但尚未完成,並且ctx的Done通道在超時處理程序中超時。


package main 

import (
    "context" 
    "fmt" 
    "net/http" 
    "time" 
) 

func main() { 
    http.Handle("/race", handlerFunc(timeoutHandler)) 
    http.ListenAndServe(":8080", nil) 
} 

func timeoutHandler(w http.ResponseWriter, r *http.Request) error { 
    const seconds = 1 
    ctx, cancel := context.WithTimeout(r.Context(), time.Duration(seconds)*time.Second) 
    defer cancel() 

    r = r.WithContext(ctx) 

    errCh := make(chan error, 1) 
    go func() { 
    // w is not safe for concurrent use by multiple goroutines 
    errCh <- nextHandler(w, r) 
    }() 

    select { 
    case err := <-errCh: 
    return err 
    case <-ctx.Done(): 
    // w is not safe for concurrent use by multiple goroutines 
    http.Error(w, "Request timeout", 408) 
    return nil 
    } 
} 

func nextHandler(w http.ResponseWriter, r *http.Request) error { 
    // just for fun to simulate a better race condition 
    const seconds = 1 
    time.Sleep(time.Duration(seconds) * time.Second) 
    fmt.Fprint(w, "nextHandler") 
    return nil 
} 

type handlerFunc func(w http.ResponseWriter, r *http.Request) error 

func (fn handlerFunc) ServeHTTP(w http.ResponseWriter, r *http.Request) { 
    if err := fn(w, r); err != nil { 
    http.Error(w, "Server error", 500) 
    } 
} 
+2

如何通過其他的東西比原來的'ResponseWriter'設置爲'nextHandler()'?然後您必須將結果複製回'<-errCh'情況下的原始'ResponseWriter'。 –

+0

無論要寫入ResponseWriter,還需要對超時負責。你正在爲nextHandler提供一個超時上下文,所以應該能夠處理超時本身。一般來說,如果只有一個處理程序負責編寫響應,則會更容易。 – JimB

回答

0

這裏是一個可能的解決方案,它基於@安迪的評論。

responseRecorder將被傳遞給nextHandler,而記錄的響應將被複制回客戶端:

func timeoutHandler(w http.ResponseWriter, r *http.Request) error { 
    const seconds = 1 
    ctx, cancel := context.WithTimeout(r.Context(), 
     time.Duration(seconds)*time.Second) 
    defer cancel() 

    r = r.WithContext(ctx) 

    errCh := make(chan error, 1) 
    w2 := newResponseRecorder() 
    go func() { 
     errCh <- nextHandler(w2, r) 
    }() 

    select { 
    case err := <-errCh: 
     if err != nil { 
      return err 
     } 

     w2.cloneHeader(w.Header()) 
     w.WriteHeader(w2.status) 
     w.Write(w2.buf.Bytes()) 
     return nil 
    case <-ctx.Done(): 
     http.Error(w, "Request timeout", 408) 
     return nil 
    } 
} 

這裏是responseRecorder

type responseRecorder struct { 
    http.ResponseWriter 
    header http.Header 
    buf *bytes.Buffer 
    status int 
} 

func newResponseRecorder() *responseRecorder { 
    return &responseRecorder{ 
     header: http.Header{}, 
     buf: &bytes.Buffer{}, 
    } 
} 

func (w *responseRecorder) Header() http.Header { 
    return w.header 
} 

func (w *responseRecorder) cloneHeader(dst http.Header) { 
    for k, v := range w.header { 
     tmp := make([]string, len(v)) 
     copy(tmp, v) 
     dst[k] = tmp 
    } 
} 

func (w *responseRecorder) Write(data []byte) (int, error) { 
    if w.status == 0 { 
     w.WriteHeader(http.StatusOK) 
    } 
    return w.buf.Write(data) 
} 

func (w *responseRecorder) WriteHeader(status int) { 
    w.status = status 
}