From 86dfbd123db05d888c121ececc9ca781104e135e Mon Sep 17 00:00:00 2001 From: Vaci Koblizek Date: Sun, 27 Sep 2020 23:09:05 +0100 Subject: [PATCH] Serialization to and from AsynchronousByteChannel --- .../main/java/org/capnproto/Serialize.java | 170 +++++++++++++++++- .../java/org/capnproto/SerializeTest.java | 51 ++++++ 2 files changed, 219 insertions(+), 2 deletions(-) diff --git a/runtime/src/main/java/org/capnproto/Serialize.java b/runtime/src/main/java/org/capnproto/Serialize.java index 7dbfac0..60b829a 100644 --- a/runtime/src/main/java/org/capnproto/Serialize.java +++ b/runtime/src/main/java/org/capnproto/Serialize.java @@ -22,11 +22,17 @@ package org.capnproto; import java.io.IOException; -import java.nio.channels.ReadableByteChannel; -import java.nio.channels.WritableByteChannel; import java.nio.ByteBuffer; import java.nio.ByteOrder; +import java.nio.channels.AsynchronousByteChannel; +import java.nio.channels.CompletionHandler; +import java.nio.channels.ReadableByteChannel; +import java.nio.channels.WritableByteChannel; import java.util.ArrayList; +import java.util.Arrays; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.CompletionStage; +import java.util.function.Consumer; public final class Serialize { @@ -201,4 +207,164 @@ public final class Serialize { } } } + + static final class AsyncMessageReader { + + private final AsynchronousByteChannel channel; + private final ReaderOptions options; + private final CompletableFuture readCompleted = new CompletableFuture<>(); + + public AsyncMessageReader(AsynchronousByteChannel channel, ReaderOptions options) { + this.channel = channel; + this.options = options; + } + + public CompletableFuture getMessage() { + readHeader(); + return readCompleted; + } + + private void readHeader() { + read(Constants.BYTES_PER_WORD, firstWord -> { + final var segmentCount = 1 + firstWord.getInt(0); + final var segment0Size = firstWord.getInt(4); + + if (segmentCount == 1) { + readSegments(segment0Size, segmentCount, segment0Size, null); + return; + } + + // check before allocating segment size buffer + if (segmentCount > 512) { + readCompleted.completeExceptionally(new IOException("Too many segments")); + return; + } + + read(4 * (segmentCount & ~1), moreSizesRaw -> { + final var moreSizes = new int[segmentCount - 1]; + var totalWords = segment0Size; + + for (int ii = 0; ii < segmentCount - 1; ++ii) { + int size = moreSizesRaw.getInt(ii * 4); + moreSizes[ii] = size; + totalWords += size; + } + + readSegments(totalWords, segmentCount, segment0Size, moreSizes); + }); + }); + } + + private void readSegments(int totalWords, int segmentCount, int segment0Size, int[] moreSizes) { + if (totalWords > options.traversalLimitInWords) { + readCompleted.completeExceptionally( + new DecodeException("Message size exceeds traversal limit.")); + return; + } + + final var segmentSlices = new ByteBuffer[segmentCount]; + if (totalWords == 0) { + for (int ii = 0; ii < segmentCount; ++ii) { + segmentSlices[ii] = ByteBuffer.allocate(0); + } + readCompleted.complete(new MessageReader(segmentSlices, options)); + return; + } + + read(totalWords * Constants.BYTES_PER_WORD, allSegments -> { + allSegments.rewind(); + segmentSlices[0] = allSegments.slice(); + segmentSlices[0].limit(segment0Size * Constants.BYTES_PER_WORD); + segmentSlices[0].order(ByteOrder.LITTLE_ENDIAN); + + int offset = segment0Size; + for (int ii = 1; ii < segmentCount; ++ii) { + allSegments.position(offset * Constants.BYTES_PER_WORD); + var segmentSize = moreSizes[ii-1]; + segmentSlices[ii] = allSegments.slice(); + segmentSlices[ii].limit(segmentSize * Constants.BYTES_PER_WORD); + segmentSlices[ii].order(ByteOrder.LITTLE_ENDIAN); + offset += segmentSize; + } + + readCompleted.complete(new MessageReader(segmentSlices, options)); + }); + } + + private void read(int bytes, Consumer consumer) { + final var buffer = Serialize.makeByteBuffer(bytes); + final var handler = new CompletionHandler() { + @Override + public void completed(Integer result, Object attachment) { + // System.out.println("read " + result + " bytes"); + if (result <= 0) { + var text = result == 0 + ? "Read zero bytes. Is the channel in non-blocking mode?" + : "Premature EOF"; + readCompleted.completeExceptionally(new IOException(text)); + } else if (buffer.hasRemaining()) { + // retry + channel.read(buffer, null, this); + } else { + consumer.accept(buffer); + } + } + + @Override + public void failed(Throwable exc, Object attachment) { + readCompleted.completeExceptionally(exc); + } + }; + + this.channel.read(buffer, null, handler); + } + } + + public static CompletableFuture readAsync(AsynchronousByteChannel channel) { + return readAsync(channel, ReaderOptions.DEFAULT_READER_OPTIONS); + } + + public static CompletableFuture readAsync(AsynchronousByteChannel channel, ReaderOptions options) { + return new AsyncMessageReader(channel, options).getMessage(); + } + + public static CompletableFuture writeAsync(AsynchronousByteChannel outputChannel, MessageBuilder message) { + final var writeCompleted = new CompletableFuture(); + final var segments = message.getSegmentsForOutput(); + assert segments.length > 0: "Empty message"; + final int tableSize = (segments.length + 2) & (~1); + final var table = ByteBuffer.allocate(4 * tableSize); + + table.order(ByteOrder.LITTLE_ENDIAN); + table.putInt(0, segments.length - 1); + + for (int ii = 0; ii < segments.length; ++ii) { + table.putInt(4 * (ii + 1), segments[ii].limit() / 8); + } + + outputChannel.write(table, 0, new CompletionHandler() { + + @Override + public void completed(Integer result, Integer attachment) { + //System.out.println("Wrote " + result + " bytes"); + if (writeCompleted.isCancelled()) { + // TODO do we really want to interrupt here? + return; + } + + if (attachment == segments.length) { + writeCompleted.complete(null); + return; + } + + outputChannel.write(segments[attachment], attachment + 1, this); + } + + @Override + public void failed(Throwable exc, Integer attachment) { + writeCompleted.completeExceptionally(exc); + } + }); + return writeCompleted.copy(); + } } diff --git a/runtime/src/test/java/org/capnproto/SerializeTest.java b/runtime/src/test/java/org/capnproto/SerializeTest.java index 5b5c5f4..aecd486 100644 --- a/runtime/src/test/java/org/capnproto/SerializeTest.java +++ b/runtime/src/test/java/org/capnproto/SerializeTest.java @@ -21,7 +21,14 @@ package org.capnproto; +import java.io.IOException; +import java.net.SocketOptions; import java.nio.ByteBuffer; +import java.nio.channels.AsynchronousServerSocketChannel; +import java.nio.channels.AsynchronousSocketChannel; +import java.nio.channels.CompletionHandler; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ExecutionException; import org.junit.Assert; import org.junit.Test; @@ -63,6 +70,50 @@ public class SerializeTest { MessageReader messageReader = Serialize.read(ByteBuffer.wrap(exampleBytes)); checkSegmentContents(exampleSegmentCount, messageReader.arena); } + + // read via AsyncChannel + expectSerializesToAsync(exampleSegmentCount, exampleBytes); + } + + private void expectSerializesToAsync(int exampleSegmentCount, byte[] exampleBytes) throws IOException { + var done = new CompletableFuture(); + var server = AsynchronousServerSocketChannel.open(); + server.bind(null); + server.accept(null, new CompletionHandler() { + @Override + public void completed(AsynchronousSocketChannel socket, Object attachment) { + socket.write(ByteBuffer.wrap(exampleBytes), null, new CompletionHandler() { + @Override + public void completed(Integer result, Object attachment) { + done.complete(null); + } + + @Override + public void failed(Throwable exc, Object attachment) { + done.completeExceptionally(exc); + } + }); + } + + @Override + public void failed(Throwable exc, Object attachment) { + done.completeExceptionally(exc); + } + }); + + var socket = AsynchronousSocketChannel.open(); + try { + socket.connect(server.getLocalAddress()).get(); + var messageReader = Serialize.readAsync(socket).get(); + checkSegmentContents(exampleSegmentCount, messageReader.arena); + done.get(); + } + catch (InterruptedException exc) { + Assert.fail(exc.getMessage()); + } + catch (ExecutionException exc) { + Assert.fail(exc.getMessage()); + } } @Test