Bypass Scalatra if request target isn’t Scalatra controller

This commit is contained in:
Naoki Takezoe
2019-01-05 22:17:46 +09:00
parent 5ce72e2056
commit d194681981
2 changed files with 47 additions and 36 deletions

View File

@@ -7,7 +7,32 @@ import org.scalatra.ScalatraFilter
import scala.collection.mutable.ListBuffer
class CompositeScalatraFilter extends Filter {
abstract class ControllerFilter extends Filter {
def process(request: ServletRequest, response: ServletResponse, checkPath: String): Boolean
override def doFilter(request: ServletRequest, response: ServletResponse, chain: FilterChain): Unit = {
val contextPath = request.getServletContext.getContextPath
val requestPath = request.asInstanceOf[HttpServletRequest].getRequestURI.substring(contextPath.length)
val checkPath = if (requestPath.endsWith("/")) {
requestPath
} else {
requestPath + "/"
}
if (!checkPath.startsWith("/upload/") && !checkPath.startsWith("/git/") && !checkPath.startsWith("/git-lfs/") &&
!checkPath.startsWith("/assets/") && !checkPath.startsWith("/plugin-assets/")) {
val continue = process(request, response, checkPath)
if (!continue) {
return ()
}
}
chain.doFilter(request, response)
}
}
class CompositeScalatraFilter extends ControllerFilter {
private val filters = new ListBuffer[(ScalatraFilter, String)]()
@@ -29,34 +54,23 @@ class CompositeScalatraFilter extends Filter {
}
}
override def doFilter(request: ServletRequest, response: ServletResponse, chain: FilterChain): Unit = {
val contextPath = request.getServletContext.getContextPath
val requestPath = request.asInstanceOf[HttpServletRequest].getRequestURI.substring(contextPath.length)
val checkPath = if (requestPath.endsWith("/")) {
requestPath
} else {
requestPath + "/"
}
override def process(request: ServletRequest, response: ServletResponse, checkPath: String): Boolean = {
filters
.filter {
case (_, path) =>
val start = path.replaceFirst("/\\*$", "/")
checkPath.startsWith(start)
}
.foreach {
case (filter, _) =>
val mockChain = new MockFilterChain()
filter.doFilter(request, response, mockChain)
if (mockChain.continue == false) {
return false
}
}
if (!checkPath.startsWith("/upload/") && !checkPath.startsWith("/git/") && !checkPath.startsWith("/git-lfs/") &&
!checkPath.startsWith("/plugin-assets/")) {
filters
.filter {
case (_, path) =>
val start = path.replaceFirst("/\\*$", "/")
checkPath.startsWith(start)
}
.foreach {
case (filter, _) =>
val mockChain = new MockFilterChain()
filter.doFilter(request, response, mockChain)
if (mockChain.continue == false) {
return ()
}
}
}
chain.doFilter(request, response)
true
}
}

View File

@@ -6,7 +6,7 @@ import javax.servlet.http.HttpServletRequest
import gitbucket.core.controller.ControllerBase
import gitbucket.core.plugin.PluginRegistry
class PluginControllerFilter extends Filter {
class PluginControllerFilter extends ControllerFilter {
private var filterConfig: FilterConfig = null
@@ -21,16 +21,13 @@ class PluginControllerFilter extends Filter {
}
}
override def doFilter(request: ServletRequest, response: ServletResponse, chain: FilterChain): Unit = {
val contextPath = request.getServletContext.getContextPath
val requestUri = request.asInstanceOf[HttpServletRequest].getRequestURI.substring(contextPath.length)
override def process(request: ServletRequest, response: ServletResponse, checkPath: String): Boolean = {
PluginRegistry()
.getControllers()
.filter {
case (_, path) =>
val start = path.replaceFirst("/\\*$", "/")
(requestUri + "/").startsWith(start)
checkPath.startsWith(start)
}
.foreach {
case (controller, _) =>
@@ -42,11 +39,11 @@ class PluginControllerFilter extends Filter {
controller.doFilter(request, response, mockChain)
if (mockChain.continue == false) {
return ()
return false
}
}
chain.doFilter(request, response)
true
}
}