|
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162 |
- package middleware
-
- import (
- "bytes"
- "crypto/hmac"
- "crypto/sha1"
- "crypto/subtle"
- "encoding/base64"
- "encoding/hex"
- "fmt"
- "net/http"
- "net/url"
- "regexp"
- "strconv"
- "time"
-
- "github.com/goanywhere/crypto"
- )
-
- const (
- xsrfCookieName = "xsrf"
- xsrfHeaderName = "X-XSRF-Token"
- xsrfFieldName = "xsrftoken"
-
- xsrfMaxAge = 3600 * 24 * 365
- xsrfTimeout = time.Hour * 24 * 365
- )
-
- var (
- errXSRFReferer = "Referer URL is missing from the request or the value was malformed."
- errXSRFToken = "Invalid XSRF tokens"
-
- xsrfPattern = regexp.MustCompile("[^0-9a-zA-Z-_]")
- unsafeMethods = regexp.MustCompile("^(DELETE|POST|PUT)$")
- )
-
- type xsrf struct {
- *http.Request
- http.ResponseWriter
- token string
- }
-
- // See http://en.wikipedia.org/wiki/Same-origin_policy
- func (self *xsrf) checkOrigin() bool {
- if self.Request.URL.Scheme == "https" {
- // See [OWASP]; Checking the Referer Header.
- referer, err := url.Parse(self.Request.Header.Get("Referer"))
-
- if err != nil || referer.String() == "" ||
- referer.Scheme != self.Request.URL.Scheme ||
- referer.Host != self.Request.URL.Host {
-
- return false
- }
- }
- return true
- }
-
- func (self *xsrf) checkToken(token string) bool {
- // Header always takes precedance of form field since some popular
- // JavaScript frameworks allow global custom headers for all AJAX requests.
- query := self.Request.Header.Get(xsrfFieldName)
- if query == "" {
- query = self.Request.FormValue(xsrfFieldName)
- }
-
- // 1) basic length comparison.
- if query == "" || len(query) != len(token) {
- return false
- }
- // *sanitize* incoming masked token.
- query = xsrfPattern.ReplaceAllString(query, "")
-
- // 2) byte-based comparison.
- a, _ := base64.URLEncoding.DecodeString(token)
- b, _ := base64.URLEncoding.DecodeString(query)
- if subtle.ConstantTimeCompare(a, b) != 1 {
- return false
- }
-
- // 3) issued time checking.
- index := bytes.LastIndex(b, []byte{'|'})
- if index != 40 {
- return false
- }
-
- nanos, err := strconv.ParseInt(string(b[index+1:]), 10, 64)
- if err != nil {
- return false
- }
- now := time.Now()
- issueTime := time.Unix(0, nanos)
-
- if now.Sub(issueTime) >= xsrfTimeout {
- return false
- }
-
- // Ensure the token is not from the *future*, allow 1 minute grace period.
- if issueTime.After(now.Add(1 * time.Minute)) {
- return false
- }
-
- return true
- }
-
- func (self *xsrf) generate() {
- // Ensure we have XSRF token in the cookie first.
- var token string
- if cookie, err := self.Request.Cookie(xsrfCookieName); err == nil {
- if cookie.Value != "" {
- token = cookie.Value
- }
- }
- if token == "" {
- // Generate a base64-encoded token.
- nano := time.Now().UnixNano()
- hash := hmac.New(sha1.New, []byte(crypto.Random(32)))
- fmt.Fprintf(hash, "%s|%d", crypto.Random(12), nano)
- raw := fmt.Sprintf("%s|%d", hex.EncodeToString(hash.Sum(nil)), nano)
- token = base64.URLEncoding.EncodeToString([]byte(raw))
-
- // The max-age directive takes priority over Expires.
- // http://www.w3.org/Protocols/rfc2616/rfc2616-sec13.html
- cookie := new(http.Cookie)
- cookie.Name = xsrfCookieName
- cookie.Value = token
- cookie.MaxAge = xsrfMaxAge
- cookie.Path = "/"
- cookie.HttpOnly = true
- http.SetCookie(self.ResponseWriter, cookie)
- }
- self.ResponseWriter.Header()[xsrfHeaderName] = []string{token}
- self.token = token
- }
-
- // XSRF serves as Cross-Site Request Forgery protection middleware.
- func XSRF(next http.Handler) http.Handler {
- fn := func(w http.ResponseWriter, r *http.Request) {
- x := new(xsrf)
- x.Request = r
- x.ResponseWriter = w
- x.generate()
-
- if unsafeMethods.MatchString(r.Method) {
- // Ensure the URL came for "Referer" under HTTPS.
- if !x.checkOrigin() {
- http.Error(w, errXSRFReferer, http.StatusForbidden)
- }
-
- // length => bytes => issue time checkpoints.
- if !x.checkToken(x.token) {
- http.Error(w, errXSRFToken, http.StatusForbidden)
- }
- }
-
- // ensure browser will invalidate the cached XSRF token.
- w.Header().Add("Vary", "Cookie")
-
- next.ServeHTTP(w, r)
- }
- return http.HandlerFunc(fn)
- }
|