Limit request methods to HEAD/GET

This commit is contained in:
NielsAD
2020-05-27 10:48:39 +02:00
parent 1bcc66add4
commit d838f56dd7
3 changed files with 21 additions and 7 deletions

View File

@@ -26,6 +26,7 @@ Usage
|`-d` |`string` |Database location|
|`-r` |`string` |Root directory to serve|
|`-i` |`string` |Refresh interval|
|`-l` |`int` |Request rate limit (req/sec per IP)|
|`-t` |`duration`|Request timeout|
|`-forwarded`|`bool` |Trust X-Real-IP and X-Forwarded-For headers|
|`-cached` |`bool` |Serve everything from cache (rather than search/recursive queries only)|

9
fs.go
View File

@@ -316,7 +316,7 @@ func (f Files) Less(i, j int) bool {
func (fs *CachedFS) serveCache(w http.ResponseWriter, r *http.Request) {
if !fs.DBReady() {
http.Error(w, "503 service unavailable", http.StatusServiceUnavailable)
http.Error(w, "503 Service Unavailable", http.StatusServiceUnavailable)
return
}
@@ -461,7 +461,7 @@ func (fs *CachedFS) ServeHTTP(w http.ResponseWriter, r *http.Request) {
// Sitemap serves a list of all directories
func (fs *CachedFS) Sitemap(w http.ResponseWriter, r *http.Request) {
if !fs.DBReady() {
http.Error(w, "503 service unavailable", http.StatusServiceUnavailable)
http.Error(w, "503 Service Unavailable", http.StatusServiceUnavailable)
return
}
@@ -480,6 +480,9 @@ func (fs *CachedFS) Sitemap(w http.ResponseWriter, r *http.Request) {
}
defer rows.Close()
w.Header().Set("Content-Type", "text/plain; charset=utf-8")
w.Header().Set("Cache-Control", "max-age=3600")
for rows.Next() {
var path string
err = rows.Scan(&path)
@@ -496,8 +499,6 @@ func (fs *CachedFS) Sitemap(w http.ResponseWriter, r *http.Request) {
goto interr
}
w.Header().Set("Content-Type", "text/plain; charset=utf-8")
w.Header().Set("Cache-Control", "max-age=3600")
return
interr:

18
main.go
View File

@@ -28,6 +28,7 @@ var (
db = flag.String("d", ":memory:", "Database location")
dir = flag.String("r", ".", "Root directory to serve")
refresh = flag.String("i", "1h", "Refresh interval")
ratelimit = flag.Int64("l", 5, "Request rate limit (req/sec per IP)")
timeout = flag.Duration("t", time.Second, "Request timeout")
forwarded = flag.Bool("forwarded", false, "Trust X-Real-IP and X-Forwarded-For headers")
cached = flag.Bool("cached", false, "Serve everything from cache (rather than search/recursive queries only)")
@@ -84,15 +85,15 @@ func main() {
}
})
limit := stdlib.NewMiddleware(limiter.New(memory.NewStore(), limiter.Rate{Period: 1 * time.Second, Limit: 5}))
limit := stdlib.NewMiddleware(limiter.New(memory.NewStore(), limiter.Rate{Period: time.Second, Limit: *ratelimit}))
srv := &http.Server{Addr: *addr}
handleDefault := func(p string, h http.Handler) { http.Handle(p, realIP(*forwarded, h)) }
handleDefault := func(p string, h http.Handler) { http.Handle(p, realIP(*forwarded, checkMethod(h))) }
handleLimited := func(p string, h http.Handler) { handleDefault(p, limit.Handler(logRequest(http.StripPrefix(p, h)))) }
handleLimited("/idx/", fs)
handleLimited("/dl/", nodir(http.FileServer(http.Dir(fs.Root))))
handleDefault("/urllist.txt", http.HandlerFunc(fs.Sitemap))
handleLimited("/urllist.txt", http.HandlerFunc(fs.Sitemap))
handleDefault("/", pub)
go func() {
@@ -119,6 +120,17 @@ func orHyphen(s string) string {
return "-"
}
func checkMethod(han http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Method != "HEAD" && r.Method != "GET" {
http.Error(w, "405 Method Not Allowed", http.StatusMethodNotAllowed)
return
}
han.ServeHTTP(w, r)
})
}
func realIP(trustForward bool, han http.Handler) http.Handler {
if !trustForward {
return han