mirror of
https://github.com/scm-manager/scm-manager.git
synced 2026-05-07 04:47:01 +02:00
Implement more robust socket hook protocol
This commit is contained in:
@@ -79,37 +79,47 @@ class DefaultHookHandler implements HookHandler {
|
||||
LOG.warn("failed to read hook request", e);
|
||||
} finally {
|
||||
LOG.trace("close client socket");
|
||||
TransactionId.clear();
|
||||
close();
|
||||
}
|
||||
}
|
||||
|
||||
private void handleHookRequest(InputStream input, OutputStream output) throws IOException {
|
||||
Request request = Sockets.read(input, Request.class);
|
||||
Request request = Sockets.receive(input, Request.class);
|
||||
TransactionId.set(request.getTransactionId());
|
||||
Response response = handleHookRequest(request);
|
||||
Sockets.send(output, response);
|
||||
}
|
||||
|
||||
private Response handleHookRequest(Request request) {
|
||||
LOG.trace("process {} hook for node {}", request.getType(), request.getNode());
|
||||
TransactionId.set(request.getTransactionId());
|
||||
|
||||
if (!environment.isAcceptAble(request.getChallenge())) {
|
||||
LOG.warn("received hook with invalid challenge: {}", request.getChallenge());
|
||||
return error("invalid hook challenge");
|
||||
}
|
||||
|
||||
try {
|
||||
authenticate(request);
|
||||
|
||||
return fireHook(request);
|
||||
} catch (AuthenticationException ex) {
|
||||
LOG.warn("hook authentication failed", ex);
|
||||
return error("hook authentication failed");
|
||||
}
|
||||
}
|
||||
|
||||
@Nonnull
|
||||
private Response fireHook(Request request) {
|
||||
HgHookContextProvider context = hookContextProviderFactory.create(request.getRepositoryId(), request.getNode());
|
||||
|
||||
try {
|
||||
if (!environment.isAcceptAble(request.getChallenge())) {
|
||||
LOG.warn("received hook with invalid challenge: {}", request.getChallenge());
|
||||
return error("invalid hook challenge");
|
||||
}
|
||||
|
||||
authenticate(request);
|
||||
environment.setPending(request.getType() == RepositoryHookType.PRE_RECEIVE);
|
||||
|
||||
hookEventFacade.handle(request.getRepositoryId()).fireHookEvent(request.getType(), context);
|
||||
|
||||
return new Response(context.getHgMessageProvider().getMessages(), false);
|
||||
} catch (AuthenticationException ex) {
|
||||
LOG.warn("hook authentication failed", ex);
|
||||
return error("hook authentication failed");
|
||||
|
||||
} catch (NotFoundException ex) {
|
||||
LOG.warn("could not find repository with id {}", request.getRepositoryId(), ex);
|
||||
return error("repository not found");
|
||||
@@ -121,7 +131,6 @@ class DefaultHookHandler implements HookHandler {
|
||||
return error(context, "unknown error");
|
||||
} finally {
|
||||
environment.clearPendingState();
|
||||
TransactionId.clear();
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -25,14 +25,20 @@
|
||||
package sonia.scm.repository.hooks;
|
||||
|
||||
import com.fasterxml.jackson.databind.ObjectMapper;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
|
||||
import java.io.ByteArrayOutputStream;
|
||||
import java.io.EOFException;
|
||||
import java.io.IOException;
|
||||
import java.io.InputStream;
|
||||
import java.io.OutputStream;
|
||||
|
||||
class Sockets {
|
||||
|
||||
private static final Logger LOG = LoggerFactory.getLogger(Sockets.class);
|
||||
|
||||
private static final int READ_LIMIT = 8192;
|
||||
|
||||
private static final ObjectMapper objectMapper = new ObjectMapper();
|
||||
|
||||
private Sockets() {
|
||||
@@ -40,18 +46,54 @@ class Sockets {
|
||||
|
||||
static void send(OutputStream out, Object object) throws IOException {
|
||||
byte[] bytes = objectMapper.writeValueAsBytes(object);
|
||||
LOG.trace("send message length of {} to socket", bytes.length);
|
||||
writeInt(out, bytes.length);
|
||||
LOG.trace("send message to socket");
|
||||
out.write(bytes);
|
||||
out.write('\0');
|
||||
LOG.trace("flush socket");
|
||||
out.flush();
|
||||
}
|
||||
|
||||
static <T> T read(InputStream in, Class<T> type) throws IOException {
|
||||
ByteArrayOutputStream buffer = new ByteArrayOutputStream();
|
||||
int c = in.read();
|
||||
while (c != '\0') {
|
||||
buffer.write(c);
|
||||
c = in.read();
|
||||
static <T> T receive(InputStream in, Class<T> type) throws IOException {
|
||||
LOG.trace("read {} from socket", type);
|
||||
int length = readInt(in);
|
||||
LOG.trace("read message length of {} from socket", length);
|
||||
if (length > READ_LIMIT) {
|
||||
String message = String.format("received length of %d, which exceeds the limit of %d", length, READ_LIMIT);
|
||||
throw new IOException(message);
|
||||
}
|
||||
return objectMapper.readValue(buffer.toByteArray(), type);
|
||||
byte[] data = read(in, length);
|
||||
LOG.trace("convert message to {}", type);
|
||||
return objectMapper.readValue(data, type);
|
||||
}
|
||||
|
||||
private static void writeInt(OutputStream out, int value) throws IOException {
|
||||
out.write((value >>> 24) & 0xFF);
|
||||
out.write((value >>> 16) & 0xFF);
|
||||
out.write((value >>> 8) & 0xFF);
|
||||
out.write(value & 0xFF);
|
||||
}
|
||||
|
||||
private static int readInt(InputStream in) throws IOException {
|
||||
int b1 = in.read();
|
||||
int b2 = in.read();
|
||||
int b3 = in.read();
|
||||
int b4 = in.read();
|
||||
|
||||
if ((b1 | b2 | b3 | b4) < 0) {
|
||||
throw new EOFException("failed to read int from socket");
|
||||
}
|
||||
|
||||
return ((b1 << 24) + (b2 << 16) + (b3 << 8) + b4);
|
||||
}
|
||||
|
||||
private static byte[] read(InputStream in, int length) throws IOException {
|
||||
byte[] buffer = new byte[length];
|
||||
int read = in.read(buffer);
|
||||
if (read < length) {
|
||||
throw new EOFException("failed to read bytes from socket");
|
||||
}
|
||||
return buffer;
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -29,7 +29,7 @@
|
||||
# changegroup.scm = python:scmhooks.callback
|
||||
#
|
||||
|
||||
import os, sys, json, socket
|
||||
import os, sys, json, socket, struct
|
||||
|
||||
# read environment
|
||||
port = os.environ['SCM_HOOK_PORT']
|
||||
@@ -54,17 +54,19 @@ def fire_hook(ui, repo, hooktype, node):
|
||||
values = {'token': token, 'type': hooktype, 'repositoryId': repositoryId, 'transactionId': transactionId, 'challenge': challenge, 'node': node.decode('utf8') }
|
||||
|
||||
connection.connect(("127.0.0.1", int(port)))
|
||||
connection.send(json.dumps(values).encode('utf-8'))
|
||||
connection.sendall(b'\0')
|
||||
|
||||
received = []
|
||||
byte = connection.recv(1)
|
||||
while byte != b'\0':
|
||||
received.append(byte)
|
||||
byte = connection.recv(1)
|
||||
data = json.dumps(values).encode('utf-8')
|
||||
connection.send(struct.pack('>i', len(data)))
|
||||
connection.sendall(data)
|
||||
|
||||
message = b''.join(received).decode('utf-8')
|
||||
response = json.loads(message)
|
||||
d = connection.recv(4, socket.MSG_WAITALL)
|
||||
length = struct.unpack('>i', bytearray(d))[0]
|
||||
if length > 8192:
|
||||
ui.warn( b"scm-hook received message with exceeds the limit of 8192\n" )
|
||||
return True
|
||||
|
||||
d = connection.recv(length, socket.MSG_WAITALL)
|
||||
response = json.loads(d.decode("utf-8"))
|
||||
|
||||
abort = response['abort']
|
||||
print_messages(ui, response['messages'])
|
||||
@@ -94,7 +96,7 @@ def pre_hook(ui, repo, hooktype, node=None, source=None, pending=None, **kwargs)
|
||||
|
||||
# newer mercurial version
|
||||
# we have to make in-memory changes visible to external process
|
||||
# this does not happen automatically, because mercurial treat our hooks as internal hooks
|
||||
# this does not happen automatically, because mercurial treat our hooks as internal hook
|
||||
# see hook.py at mercurial sources _exthook
|
||||
try:
|
||||
if repo is not None:
|
||||
@@ -103,7 +105,7 @@ def pre_hook(ui, repo, hooktype, node=None, source=None, pending=None, **kwargs)
|
||||
if tr and not tr.writepending():
|
||||
ui.warn(b"no pending write transaction found")
|
||||
except AttributeError:
|
||||
ui.debug(b"mercurial does not support currenttransation")
|
||||
ui.debug(b"mercurial does not support currenttransaction")
|
||||
# do nothing
|
||||
|
||||
return callback(ui, repo, "PRE_RECEIVE", node)
|
||||
|
||||
@@ -291,14 +291,16 @@ class DefaultHookHandlerTest {
|
||||
private DefaultHookHandler.Response send(DefaultHookHandler.Request request) throws IOException {
|
||||
ByteArrayOutputStream buffer = new ByteArrayOutputStream();
|
||||
Sockets.send(buffer, request);
|
||||
|
||||
ByteArrayInputStream input = new ByteArrayInputStream(buffer.toByteArray());
|
||||
when(socket.getInputStream()).thenReturn(input);
|
||||
|
||||
ByteArrayOutputStream output = new ByteArrayOutputStream();
|
||||
when(socket.getOutputStream()).thenReturn(output);
|
||||
|
||||
handler.run();
|
||||
|
||||
return Sockets.read(new ByteArrayInputStream(output.toByteArray()), DefaultHookHandler.Response.class);
|
||||
return Sockets.receive(new ByteArrayInputStream(output.toByteArray()), DefaultHookHandler.Response.class);
|
||||
}
|
||||
|
||||
private static class TestingException extends ExceptionWithContext {
|
||||
|
||||
@@ -82,7 +82,7 @@ class HookServerTest {
|
||||
OutputStream output = socket.getOutputStream()
|
||||
) {
|
||||
Sockets.send(output, request);
|
||||
return Sockets.read(input, Response.class);
|
||||
return Sockets.receive(input, Response.class);
|
||||
} catch (IOException ex) {
|
||||
throw new RuntimeException("failed", ex);
|
||||
}
|
||||
@@ -100,7 +100,7 @@ class HookServerTest {
|
||||
@Override
|
||||
public void run() {
|
||||
try (InputStream input = socket.getInputStream(); OutputStream output = socket.getOutputStream()) {
|
||||
Request request = Sockets.read(input, Request.class);
|
||||
Request request = Sockets.receive(input, Request.class);
|
||||
Subject subject = SecurityUtils.getSubject();
|
||||
Sockets.send(output, new Response("Hello " + request.getName(), subject.getPrincipal().toString()));
|
||||
} catch (IOException ex) {
|
||||
|
||||
@@ -0,0 +1,94 @@
|
||||
/*
|
||||
* MIT License
|
||||
*
|
||||
* Copyright (c) 2020-present Cloudogu GmbH and Contributors
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
* of this software and associated documentation files (the "Software"), to deal
|
||||
* in the Software without restriction, including without limitation the rights
|
||||
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
* copies of the Software, and to permit persons to whom the Software is
|
||||
* furnished to do so, subject to the following conditions:
|
||||
*
|
||||
* The above copyright notice and this permission notice shall be included in all
|
||||
* copies or substantial portions of the Software.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
* SOFTWARE.
|
||||
*/
|
||||
|
||||
package sonia.scm.repository.hooks;
|
||||
|
||||
import lombok.AllArgsConstructor;
|
||||
import lombok.Data;
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
import java.io.ByteArrayInputStream;
|
||||
import java.io.ByteArrayOutputStream;
|
||||
import java.io.IOException;
|
||||
|
||||
import static org.assertj.core.api.Assertions.assertThat;
|
||||
import static org.junit.jupiter.api.Assertions.assertThrows;
|
||||
|
||||
class SocketsTest {
|
||||
|
||||
@Test
|
||||
void shouldSendAndReceive() throws IOException {
|
||||
ByteArrayOutputStream output = new ByteArrayOutputStream();
|
||||
Sockets.send(output, new TestValue("awesome"));
|
||||
ByteArrayInputStream input = new ByteArrayInputStream(output.toByteArray());
|
||||
TestValue value = Sockets.receive(input, TestValue.class);
|
||||
assertThat(value.value).isEqualTo("awesome");
|
||||
}
|
||||
|
||||
@Test
|
||||
void shouldFailWithTooFewBytesForLength() {
|
||||
ByteArrayOutputStream output = new ByteArrayOutputStream();
|
||||
output.write((512 >>> 24) & 0xFF);
|
||||
output.write((512 >>> 16) & 0xFF);
|
||||
|
||||
ByteArrayInputStream input = new ByteArrayInputStream(output.toByteArray());
|
||||
IOException ex = assertThrows(IOException.class, () -> Sockets.receive(input, TestValue.class));
|
||||
assertThat(ex.getMessage()).contains("int");
|
||||
}
|
||||
|
||||
@Test
|
||||
void shouldFailWithTooFewBytesForData() {
|
||||
ByteArrayOutputStream output = new ByteArrayOutputStream();
|
||||
output.write((16 >>> 24) & 0xFF);
|
||||
output.write((16 >>> 16) & 0xFF);
|
||||
output.write((16 >>> 8) & 0xFF);
|
||||
output.write(16 & 0xFF);
|
||||
|
||||
ByteArrayInputStream input = new ByteArrayInputStream(output.toByteArray());
|
||||
IOException ex = assertThrows(IOException.class, () -> Sockets.receive(input, TestValue.class));
|
||||
assertThat(ex.getMessage()).contains("bytes");
|
||||
}
|
||||
|
||||
@Test
|
||||
void shouldFailIfLimitIsExceeded() {
|
||||
ByteArrayOutputStream output = new ByteArrayOutputStream();
|
||||
output.write((9216 >>> 24) & 0xFF);
|
||||
output.write((9216 >>> 16) & 0xFF);
|
||||
output.write((9216 >>> 8) & 0xFF);
|
||||
output.write(9216 & 0xFF);
|
||||
|
||||
ByteArrayInputStream input = new ByteArrayInputStream(output.toByteArray());
|
||||
IOException ex = assertThrows(IOException.class, () -> Sockets.receive(input, TestValue.class));
|
||||
assertThat(ex.getMessage()).contains("9216");
|
||||
}
|
||||
|
||||
@Data
|
||||
@AllArgsConstructor
|
||||
public static class TestValue {
|
||||
|
||||
private String value;
|
||||
|
||||
}
|
||||
|
||||
}
|
||||
Reference in New Issue
Block a user