diff --git a/README.md b/README.md index 570d1f4..e18652b 100644 --- a/README.md +++ b/README.md @@ -26,6 +26,7 @@ Usage |`-d` |`string` |Database location| |`-r` |`string` |Root directory to serve| |`-i` |`string` |Refresh interval| +|`-l` |`int` |Request rate limit (req/sec per IP)| |`-t` |`duration`|Request timeout| |`-forwarded`|`bool` |Trust X-Real-IP and X-Forwarded-For headers| |`-cached` |`bool` |Serve everything from cache (rather than search/recursive queries only)| diff --git a/fs.go b/fs.go index ca10de0..8a24b9c 100644 --- a/fs.go +++ b/fs.go @@ -316,7 +316,7 @@ func (f Files) Less(i, j int) bool { func (fs *CachedFS) serveCache(w http.ResponseWriter, r *http.Request) { if !fs.DBReady() { - http.Error(w, "503 service unavailable", http.StatusServiceUnavailable) + http.Error(w, "503 Service Unavailable", http.StatusServiceUnavailable) return } @@ -461,7 +461,7 @@ func (fs *CachedFS) ServeHTTP(w http.ResponseWriter, r *http.Request) { // Sitemap serves a list of all directories func (fs *CachedFS) Sitemap(w http.ResponseWriter, r *http.Request) { if !fs.DBReady() { - http.Error(w, "503 service unavailable", http.StatusServiceUnavailable) + http.Error(w, "503 Service Unavailable", http.StatusServiceUnavailable) return } @@ -480,6 +480,9 @@ func (fs *CachedFS) Sitemap(w http.ResponseWriter, r *http.Request) { } defer rows.Close() + w.Header().Set("Content-Type", "text/plain; charset=utf-8") + w.Header().Set("Cache-Control", "max-age=3600") + for rows.Next() { var path string err = rows.Scan(&path) @@ -496,8 +499,6 @@ func (fs *CachedFS) Sitemap(w http.ResponseWriter, r *http.Request) { goto interr } - w.Header().Set("Content-Type", "text/plain; charset=utf-8") - w.Header().Set("Cache-Control", "max-age=3600") return interr: diff --git a/main.go b/main.go index 0c30576..b5a23ad 100644 --- a/main.go +++ b/main.go @@ -28,6 +28,7 @@ var ( db = flag.String("d", ":memory:", "Database location") dir = flag.String("r", ".", "Root directory to serve") refresh = flag.String("i", "1h", "Refresh interval") + ratelimit = flag.Int64("l", 5, "Request rate limit (req/sec per IP)") timeout = flag.Duration("t", time.Second, "Request timeout") forwarded = flag.Bool("forwarded", false, "Trust X-Real-IP and X-Forwarded-For headers") cached = flag.Bool("cached", false, "Serve everything from cache (rather than search/recursive queries only)") @@ -84,15 +85,15 @@ func main() { } }) - limit := stdlib.NewMiddleware(limiter.New(memory.NewStore(), limiter.Rate{Period: 1 * time.Second, Limit: 5})) + limit := stdlib.NewMiddleware(limiter.New(memory.NewStore(), limiter.Rate{Period: time.Second, Limit: *ratelimit})) srv := &http.Server{Addr: *addr} - handleDefault := func(p string, h http.Handler) { http.Handle(p, realIP(*forwarded, h)) } + handleDefault := func(p string, h http.Handler) { http.Handle(p, realIP(*forwarded, checkMethod(h))) } handleLimited := func(p string, h http.Handler) { handleDefault(p, limit.Handler(logRequest(http.StripPrefix(p, h)))) } handleLimited("/idx/", fs) handleLimited("/dl/", nodir(http.FileServer(http.Dir(fs.Root)))) - handleDefault("/urllist.txt", http.HandlerFunc(fs.Sitemap)) + handleLimited("/urllist.txt", http.HandlerFunc(fs.Sitemap)) handleDefault("/", pub) go func() { @@ -119,6 +120,17 @@ func orHyphen(s string) string { return "-" } +func checkMethod(han http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != "HEAD" && r.Method != "GET" { + http.Error(w, "405 Method Not Allowed", http.StatusMethodNotAllowed) + return + } + + han.ServeHTTP(w, r) + }) +} + func realIP(trustForward bool, han http.Handler) http.Handler { if !trustForward { return han