Serialization to and from AsynchronousByteChannel

This commit is contained in:
Vaci Koblizek 2020-09-27 23:09:05 +01:00
parent 73bc7a6569
commit 86dfbd123d
2 changed files with 219 additions and 2 deletions

View file

@ -22,11 +22,17 @@
package org.capnproto;
import java.io.IOException;
import java.nio.channels.ReadableByteChannel;
import java.nio.channels.WritableByteChannel;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.channels.AsynchronousByteChannel;
import java.nio.channels.CompletionHandler;
import java.nio.channels.ReadableByteChannel;
import java.nio.channels.WritableByteChannel;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionStage;
import java.util.function.Consumer;
public final class Serialize {
@ -201,4 +207,164 @@ public final class Serialize {
}
}
}
static final class AsyncMessageReader {
private final AsynchronousByteChannel channel;
private final ReaderOptions options;
private final CompletableFuture<MessageReader> readCompleted = new CompletableFuture<>();
public AsyncMessageReader(AsynchronousByteChannel channel, ReaderOptions options) {
this.channel = channel;
this.options = options;
}
public CompletableFuture<MessageReader> getMessage() {
readHeader();
return readCompleted;
}
private void readHeader() {
read(Constants.BYTES_PER_WORD, firstWord -> {
final var segmentCount = 1 + firstWord.getInt(0);
final var segment0Size = firstWord.getInt(4);
if (segmentCount == 1) {
readSegments(segment0Size, segmentCount, segment0Size, null);
return;
}
// check before allocating segment size buffer
if (segmentCount > 512) {
readCompleted.completeExceptionally(new IOException("Too many segments"));
return;
}
read(4 * (segmentCount & ~1), moreSizesRaw -> {
final var moreSizes = new int[segmentCount - 1];
var totalWords = segment0Size;
for (int ii = 0; ii < segmentCount - 1; ++ii) {
int size = moreSizesRaw.getInt(ii * 4);
moreSizes[ii] = size;
totalWords += size;
}
readSegments(totalWords, segmentCount, segment0Size, moreSizes);
});
});
}
private void readSegments(int totalWords, int segmentCount, int segment0Size, int[] moreSizes) {
if (totalWords > options.traversalLimitInWords) {
readCompleted.completeExceptionally(
new DecodeException("Message size exceeds traversal limit."));
return;
}
final var segmentSlices = new ByteBuffer[segmentCount];
if (totalWords == 0) {
for (int ii = 0; ii < segmentCount; ++ii) {
segmentSlices[ii] = ByteBuffer.allocate(0);
}
readCompleted.complete(new MessageReader(segmentSlices, options));
return;
}
read(totalWords * Constants.BYTES_PER_WORD, allSegments -> {
allSegments.rewind();
segmentSlices[0] = allSegments.slice();
segmentSlices[0].limit(segment0Size * Constants.BYTES_PER_WORD);
segmentSlices[0].order(ByteOrder.LITTLE_ENDIAN);
int offset = segment0Size;
for (int ii = 1; ii < segmentCount; ++ii) {
allSegments.position(offset * Constants.BYTES_PER_WORD);
var segmentSize = moreSizes[ii-1];
segmentSlices[ii] = allSegments.slice();
segmentSlices[ii].limit(segmentSize * Constants.BYTES_PER_WORD);
segmentSlices[ii].order(ByteOrder.LITTLE_ENDIAN);
offset += segmentSize;
}
readCompleted.complete(new MessageReader(segmentSlices, options));
});
}
private void read(int bytes, Consumer<ByteBuffer> consumer) {
final var buffer = Serialize.makeByteBuffer(bytes);
final var handler = new CompletionHandler<Integer, Object>() {
@Override
public void completed(Integer result, Object attachment) {
// System.out.println("read " + result + " bytes");
if (result <= 0) {
var text = result == 0
? "Read zero bytes. Is the channel in non-blocking mode?"
: "Premature EOF";
readCompleted.completeExceptionally(new IOException(text));
} else if (buffer.hasRemaining()) {
// retry
channel.read(buffer, null, this);
} else {
consumer.accept(buffer);
}
}
@Override
public void failed(Throwable exc, Object attachment) {
readCompleted.completeExceptionally(exc);
}
};
this.channel.read(buffer, null, handler);
}
}
public static CompletableFuture<MessageReader> readAsync(AsynchronousByteChannel channel) {
return readAsync(channel, ReaderOptions.DEFAULT_READER_OPTIONS);
}
public static CompletableFuture<MessageReader> readAsync(AsynchronousByteChannel channel, ReaderOptions options) {
return new AsyncMessageReader(channel, options).getMessage();
}
public static CompletableFuture<java.lang.Void> writeAsync(AsynchronousByteChannel outputChannel, MessageBuilder message) {
final var writeCompleted = new CompletableFuture<java.lang.Void>();
final var segments = message.getSegmentsForOutput();
assert segments.length > 0: "Empty message";
final int tableSize = (segments.length + 2) & (~1);
final var table = ByteBuffer.allocate(4 * tableSize);
table.order(ByteOrder.LITTLE_ENDIAN);
table.putInt(0, segments.length - 1);
for (int ii = 0; ii < segments.length; ++ii) {
table.putInt(4 * (ii + 1), segments[ii].limit() / 8);
}
outputChannel.write(table, 0, new CompletionHandler<Integer, Integer>() {
@Override
public void completed(Integer result, Integer attachment) {
//System.out.println("Wrote " + result + " bytes");
if (writeCompleted.isCancelled()) {
// TODO do we really want to interrupt here?
return;
}
if (attachment == segments.length) {
writeCompleted.complete(null);
return;
}
outputChannel.write(segments[attachment], attachment + 1, this);
}
@Override
public void failed(Throwable exc, Integer attachment) {
writeCompleted.completeExceptionally(exc);
}
});
return writeCompleted.copy();
}
}

View file

@ -21,7 +21,14 @@
package org.capnproto;
import java.io.IOException;
import java.net.SocketOptions;
import java.nio.ByteBuffer;
import java.nio.channels.AsynchronousServerSocketChannel;
import java.nio.channels.AsynchronousSocketChannel;
import java.nio.channels.CompletionHandler;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutionException;
import org.junit.Assert;
import org.junit.Test;
@ -63,6 +70,50 @@ public class SerializeTest {
MessageReader messageReader = Serialize.read(ByteBuffer.wrap(exampleBytes));
checkSegmentContents(exampleSegmentCount, messageReader.arena);
}
// read via AsyncChannel
expectSerializesToAsync(exampleSegmentCount, exampleBytes);
}
private void expectSerializesToAsync(int exampleSegmentCount, byte[] exampleBytes) throws IOException {
var done = new CompletableFuture<java.lang.Void>();
var server = AsynchronousServerSocketChannel.open();
server.bind(null);
server.accept(null, new CompletionHandler<AsynchronousSocketChannel, Object>() {
@Override
public void completed(AsynchronousSocketChannel socket, Object attachment) {
socket.write(ByteBuffer.wrap(exampleBytes), null, new CompletionHandler<Integer, Object>() {
@Override
public void completed(Integer result, Object attachment) {
done.complete(null);
}
@Override
public void failed(Throwable exc, Object attachment) {
done.completeExceptionally(exc);
}
});
}
@Override
public void failed(Throwable exc, Object attachment) {
done.completeExceptionally(exc);
}
});
var socket = AsynchronousSocketChannel.open();
try {
socket.connect(server.getLocalAddress()).get();
var messageReader = Serialize.readAsync(socket).get();
checkSegmentContents(exampleSegmentCount, messageReader.arena);
done.get();
}
catch (InterruptedException exc) {
Assert.fail(exc.getMessage());
}
catch (ExecutionException exc) {
Assert.fail(exc.getMessage());
}
}
@Test