(refs #110)Correct authentication and CommitHook

This commit is contained in:
Tomofumi Tanaka
2014-03-04 22:49:25 +09:00
parent 79e1abe624
commit 09b7e67c52
3 changed files with 99 additions and 66 deletions

View File

@@ -9,15 +9,15 @@ import util.Directory._
import org.eclipse.jgit.transport.{ReceivePack, UploadPack} import org.eclipse.jgit.transport.{ReceivePack, UploadPack}
import org.apache.sshd.server.command.UnknownCommand import org.apache.sshd.server.command.UnknownCommand
import servlet.{Database, CommitLogHook} import servlet.{Database, CommitLogHook}
import service.SystemSettingsService.SystemSettings
import service.SystemSettingsService import service.SystemSettingsService
import org.eclipse.jgit.errors.RepositoryNotFoundException
class GitCommandFactory extends CommandFactory { class GitCommandFactory extends CommandFactory {
private val logger = LoggerFactory.getLogger(classOf[GitCommandFactory]) private val logger = LoggerFactory.getLogger(classOf[GitCommandFactory])
override def createCommand(command: String): Command = { override def createCommand(command: String): Command = {
logger.info(s"command: String -> " + command) logger.debug(s"command: $command")
command match { command match {
// TODO MUST use regular expression and UnitTest // TODO MUST use regular expression and UnitTest
case s if s.startsWith("git-upload-pack") => new GitUploadPack(command) case s if s.startsWith("git-upload-pack") => new GitUploadPack(command)
@@ -28,24 +28,24 @@ class GitCommandFactory extends CommandFactory {
} }
abstract class GitCommand(val command: String) extends Command { abstract class GitCommand(val command: String) extends Command {
private val logger = LoggerFactory.getLogger(classOf[GitCommand]) protected val logger = LoggerFactory.getLogger(classOf[GitCommand])
protected val (gitCommand, owner, repositoryName) = parseCommand protected val (gitCommand, owner, repositoryName) = parseCommand
protected var err: OutputStream = null protected var err: OutputStream = null
protected var in: InputStream = null protected var in: InputStream = null
protected var out: OutputStream = null protected var out: OutputStream = null
protected var callback: ExitCallback = null protected var callback: ExitCallback = null
protected def runnable: Runnable protected def runnable(user: String): Runnable
override def start(env: Environment): Unit = { override def start(env: Environment): Unit = {
logger.info(s"start command : " + command) logger.info(s"start command : " + command)
logger.info(s"parsed command : $gitCommand, $owner, $repositoryName") logger.info(s"parsed command : $gitCommand, $owner, $repositoryName")
val thread = new Thread(runnable) val user = env.getEnv.get("USER")
val thread = new Thread(runnable(user))
thread.start() thread.start()
} }
override def destroy(): Unit = { override def destroy(): Unit = {}
}
override def setExitCallback(callback: ExitCallback): Unit = { override def setExitCallback(callback: ExitCallback): Unit = {
this.callback = callback this.callback = callback
@@ -64,47 +64,70 @@ abstract class GitCommand(val command: String) extends Command {
} }
private def parseCommand: (String, String, String) = { private def parseCommand: (String, String, String) = {
// command sample: git-upload-pack '/username/repository_name.git' // command sample: git-upload-pack '/owner/repository_name.git'
// command sample: git-receive-pack '/username/repository_name.git' // command sample: git-receive-pack '/owner/repository_name.git'
// TODO This is not correct.... // TODO This is not correct.... but works
val split = command.split(" ") val split = command.split(" ")
val gitCommand = split(0) val gitCommand = split(0)
val gitUser = split(1).substring(1, split(1).length - 5).split("/")(1) val owner = split(1).substring(1, split(1).length - 5).split("/")(1)
val gitRepo = split(1).substring(1, split(1).length - 5).split("/")(2) val repositoryName = split(1).substring(1, split(1).length - 5).split("/")(2)
(gitCommand, gitUser, gitRepo) (gitCommand, owner, repositoryName)
} }
} }
class GitUploadPack(command: String) extends GitCommand(command: String) { class GitUploadPack(override val command: String) extends GitCommand(command: String) {
override def runnable = new Runnable { override def runnable(user: String) = new Runnable {
override def run(): Unit = { override def run(): Unit = {
using(Git.open(getRepositoryDir(owner, repositoryName))) { git => try {
val repository = git.getRepository using(Git.open(getRepositoryDir(owner, repositoryName))) {
val upload = new UploadPack(repository) git =>
upload.upload(in, out, err) val repository = git.getRepository
callback.onExit(0) val upload = new UploadPack(repository)
try {
upload.upload(in, out, err)
callback.onExit(0)
} catch {
case e: Throwable =>
logger.error(e.getMessage, e)
callback.onExit(1)
}
}
} catch {
case e: RepositoryNotFoundException =>
logger.info(e.getMessage, e)
callback.onExit(1)
} }
} }
} }
} }
class GitReceivePack(command: String) extends GitCommand(command: String) with SystemSettingsService { class GitReceivePack(override val command: String) extends GitCommand(command: String) with SystemSettingsService {
override def runnable = new Runnable { // TODO Correct this info. where i get base url?
val BaseURL: String = loadSystemSettings().baseUrl.getOrElse("http://localhost:8080")
// TODO correct this info
val pusher: String = "user1"
val baseURL: String = loadSystemSettings().baseUrl.getOrElse("http://localhost:8080")
override def runnable(user: String) = new Runnable {
override def run(): Unit = { override def run(): Unit = {
using(Git.open(getRepositoryDir(owner, repositoryName))) { git => try {
val repository = git.getRepository using(Git.open(getRepositoryDir(owner, repositoryName))) {
// TODO hook commit git =>
val receive = new ReceivePack(repository) val repository = git.getRepository
receive.setPostReceiveHook(new CommitLogHook(owner, repositoryName, pusher, baseURL)) val receive = new ReceivePack(repository)
Database(SshServer.getServletContext) withTransaction { receive.setPostReceiveHook(new CommitLogHook(owner, repositoryName, user, BaseURL))
receive.receive(in, out, err) Database(SshServer.getServletContext) withTransaction {
callback.onExit(0) try {
} receive.receive(in, out, err)
callback.onExit(0)
} catch {
case e: Throwable =>
logger.error(e.getMessage, e)
callback.onExit(1)
}
}
}
} catch {
case e: RepositoryNotFoundException =>
logger.info(e.getMessage, e)
callback.onExit(1)
} }
} }
} }

View File

@@ -1,37 +1,49 @@
package ssh package ssh
import org.apache.sshd.server.{PublickeyAuthenticator, PasswordAuthenticator} import org.apache.sshd.server.PublickeyAuthenticator
import org.slf4j.LoggerFactory import org.slf4j.LoggerFactory
import org.apache.sshd.server.session.ServerSession import org.apache.sshd.server.session.ServerSession
import java.security.{KeyFactory, PublicKey} import java.security.PublicKey
import org.apache.commons.codec.binary.Base64 import org.apache.commons.codec.binary.Base64
import java.security.spec.X509EncodedKeySpec
import org.apache.sshd.common.util.Buffer import org.apache.sshd.common.util.Buffer
import org.eclipse.jgit.lib.Constants
object DummyData {
val userPublicKeys = List(
"ssh-rsa AAB3NzaC1yc2EAAAADAQABAAABAQDRzuX0WtSLzCY45nEhfFDPXzYGmvQdqnOgOUY4yGL5io/2ztyUvJdhWowkyakeoPxVk/jIP7Tu8Are5TuSD+fJp7aUbZW2CYOEsxo8cwndh/ezIX6RFjlu+xvKvZ8G7BtFLlLCcnza9uB+uEAyPH5HvGQLdV7dXctLfFqXPTr1p1RjSI7Noubm+vN4n9108rILd32MlhQiToXjL4HKWWwmppaln6bEsonOQW4/GieRjQeyWDkbVekIofnedjWl4+W0kAA+WosNwRFShgsaJLfU964HT/cGjK5auqOG+nATY0suECnxAK+5Wb6jXXYNmKiIMHypeXG1Qy2wMyMB1Gq9 tanacasino-local",
"ssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAABAQDRzuX0WtSLzCY45nEhfFDPXzYGmvQdqnOgOUY4yGL5io/2ztyUvJdhWowkyakeoPxVk/jIP7Tu8Are5TuSD+fJp7aUbZW2CYOEsxo8cwndh/ezIX6RFjlu+xvKvZ8G7BtFLlLCcnza9uB+uEAyPH5HvGQLdV7dXctLfFqXPTr1p1RjSI7Noubm+vN4n9108rILd32MlhQiToXjL4HKWWwmppaln6bEsonOQW4/GieRjQeyWDkbVekIofnedjWl4+W0kAA+WosNwRFShgsaJLfU964HT/cGjK5auqOG+nATY0suECnxAK+5Wb6jXXYNmKiIMHypeXG1Qy2wMyMB1Gq9 tanacasino-local"
)
}
class PublicKeyAuthenticator extends PublickeyAuthenticator { class PublicKeyAuthenticator extends PublickeyAuthenticator {
private val logger = LoggerFactory.getLogger(classOf[PublicKeyAuthenticator]) private val logger = LoggerFactory.getLogger(classOf[PublicKeyAuthenticator])
override def authenticate(username: String, key: PublicKey, session: ServerSession): Boolean = { override def authenticate(username: String, key: PublicKey, session: ServerSession): Boolean = {
// TODO this string is read from DB and Users register this public key string on Account Profile view // TODO userPublicKeys is read from DB and Users register this public key string list on Account Profile view
val testAuthkey = "ssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAABAQDRzuX0WtSLzCY45nEhfFDPXzYGmvQdqnOgOUY4yGL5io/2ztyUvJdhWowkyakeoPxVk/jIP7Tu8Are5TuSD+fJp7aUbZW2CYOEsxo8cwndh/ezIX6RFjlu+xvKvZ8G7BtFLlLCcnza9uB+uEAyPH5HvGQLdV7dXctLfFqXPTr1p1RjSI7Noubm+vN4n9108rILd32MlhQiToXjL4HKWWwmppaln6bEsonOQW4/GieRjQeyWDkbVekIofnedjWl4+W0kAA+WosNwRFShgsaJLfU964HT/cGjK5auqOG+nATY0suECnxAK+5Wb6jXXYNmKiIMHypeXG1Qy2wMyMB1Gq9 tanacasino-local" DummyData.userPublicKeys.exists(str => str2PublicKey(str) match {
toPublicKey(testAuthkey) match {
case Some(publicKey) => key.equals(publicKey) case Some(publicKey) => key.equals(publicKey)
case _ => false case _ => false
})
}
private def str2PublicKey(key: String): Option[PublicKey] = {
// TODO RFC 4716 Public Key is not supported...
val parts = key.split(" ")
if (parts.size < 2) {
logger.debug(s"Invalid PublicKey Format: key")
return None
}
try {
val encodedKey = parts(1)
val decode = Base64.decodeBase64(Constants.encodeASCII(encodedKey))
Some(new Buffer(decode).getRawPublicKey)
} catch {
case e: Throwable =>
logger.debug(e.getMessage, e)
None
} }
} }
private def toPublicKey(key: String): Option[PublicKey] = {
try {
val parts = key.split(" ")
val encodedKey = key.split(" ")(1)
val decode = Base64.decodeBase64(encodedKey)
Some(new Buffer(decode).getRawPublicKey)
} catch {
case e: Throwable => {
logger.error(e.getMessage, e)
None
}
}
}
} }

View File

@@ -3,30 +3,30 @@ package ssh
import javax.servlet.{ServletContext, ServletContextEvent, ServletContextListener} import javax.servlet.{ServletContext, ServletContextEvent, ServletContextListener}
import org.apache.sshd.server.keyprovider.SimpleGeneratorHostKeyProvider import org.apache.sshd.server.keyprovider.SimpleGeneratorHostKeyProvider
import org.slf4j.LoggerFactory import org.slf4j.LoggerFactory
import util.Directory
object SshServer { object SshServer {
private val logger = LoggerFactory.getLogger(SshServer.getClass) private val logger = LoggerFactory.getLogger(SshServer.getClass)
val DEFAULT_PORT: Int = 29418 // TODO read from config val DEFAULT_PORT: Int = 29418
val SSH_SERVICE_ENABLE = true // TODO read from config
val SSH_SERVICE_ENABLE = true // TODO read from config
private val server = org.apache.sshd.SshServer.setUpDefaultServer() private val server = org.apache.sshd.SshServer.setUpDefaultServer()
// TODO think other way to create database session // TODO think other way. this is for create database session
private var context: ServletContext = null private var context: ServletContext = null
private def configure() = { private def configure() = {
server.setPort(DEFAULT_PORT) server.setPort(DEFAULT_PORT) // TODO read from config
// TODO gitbucket.ser should be in GITBUCKET_HOME server.setKeyPairProvider(new SimpleGeneratorHostKeyProvider(s"${Directory.GitBucketHome}/gitbucket.ser"))
server.setKeyPairProvider(new SimpleGeneratorHostKeyProvider("gitbucket.ser"))
server.setPublickeyAuthenticator(new PublicKeyAuthenticator) server.setPublickeyAuthenticator(new PublicKeyAuthenticator)
server.setCommandFactory(new GitCommandFactory) server.setCommandFactory(new GitCommandFactory)
} }
def start(context: ServletContext) = { def start(context: ServletContext) = this.synchronized {
if (SSH_SERVICE_ENABLE) { if (SSH_SERVICE_ENABLE) {
this.context = context this.context = context
configure() configure()
@@ -39,7 +39,7 @@ object SshServer {
server.stop(true) server.stop(true)
} }
def getServletContext = this.context; def getServletContext = this.context
} }
/* /*
@@ -52,7 +52,7 @@ object SshServer {
class SshServerListener extends ServletContextListener { class SshServerListener extends ServletContextListener {
override def contextInitialized(sce: ServletContextEvent): Unit = { override def contextInitialized(sce: ServletContextEvent): Unit = {
SshServer.start(sce.getServletContext()) SshServer.start(sce.getServletContext)
} }
override def contextDestroyed(sce: ServletContextEvent): Unit = { override def contextDestroyed(sce: ServletContextEvent): Unit = {
@@ -60,5 +60,3 @@ class SshServerListener extends ServletContextListener {
} }
} }