From e3d52a0bbd2dc1753e761a44c20c22c94090f329 Mon Sep 17 00:00:00 2001 From: Vaci Koblizek Date: Mon, 23 Nov 2020 12:44:46 +0000 Subject: [PATCH] use gather writes for AsynchronousSocketChannels --- .../java/org/capnproto/TwoPartyClient.java | 9 +- .../java/org/capnproto/TwoPartyServer.java | 4 +- .../org/capnproto/TwoPartyVatNetwork.java | 5 +- .../main/java/org/capnproto/Serialize.java | 224 +++++++++++++----- .../java/org/capnproto/SerializeTest.java | 10 +- 5 files changed, 181 insertions(+), 71 deletions(-) diff --git a/runtime-rpc/src/main/java/org/capnproto/TwoPartyClient.java b/runtime-rpc/src/main/java/org/capnproto/TwoPartyClient.java index d1d3159..5c21a48 100644 --- a/runtime-rpc/src/main/java/org/capnproto/TwoPartyClient.java +++ b/runtime-rpc/src/main/java/org/capnproto/TwoPartyClient.java @@ -1,7 +1,6 @@ package org.capnproto; -import java.io.IOException; -import java.nio.channels.AsynchronousByteChannel; +import java.nio.channels.AsynchronousSocketChannel; import java.util.concurrent.CompletableFuture; public class TwoPartyClient { @@ -9,15 +8,15 @@ public class TwoPartyClient { private final TwoPartyVatNetwork network; private final RpcSystem rpcSystem; - public TwoPartyClient(AsynchronousByteChannel channel) { + public TwoPartyClient(AsynchronousSocketChannel channel) { this(channel, null); } - public TwoPartyClient(AsynchronousByteChannel channel, Capability.Client bootstrapInterface) { + public TwoPartyClient(AsynchronousSocketChannel channel, Capability.Client bootstrapInterface) { this(channel, bootstrapInterface, RpcTwoPartyProtocol.Side.CLIENT); } - public TwoPartyClient(AsynchronousByteChannel channel, + public TwoPartyClient(AsynchronousSocketChannel channel, Capability.Client bootstrapInterface, RpcTwoPartyProtocol.Side side) { this.network = new TwoPartyVatNetwork(channel, side); diff --git a/runtime-rpc/src/main/java/org/capnproto/TwoPartyServer.java b/runtime-rpc/src/main/java/org/capnproto/TwoPartyServer.java index 9bab21f..7f39f16 100644 --- a/runtime-rpc/src/main/java/org/capnproto/TwoPartyServer.java +++ b/runtime-rpc/src/main/java/org/capnproto/TwoPartyServer.java @@ -8,11 +8,11 @@ import java.util.concurrent.CompletableFuture; public class TwoPartyServer { private class AcceptedConnection { - private final AsynchronousByteChannel connection; + private final AsynchronousSocketChannel connection; private final TwoPartyVatNetwork network; private final RpcSystem rpcSystem; - AcceptedConnection(Capability.Client bootstrapInterface, AsynchronousByteChannel connection) { + AcceptedConnection(Capability.Client bootstrapInterface, AsynchronousSocketChannel connection) { this.connection = connection; this.network = new TwoPartyVatNetwork(this.connection, RpcTwoPartyProtocol.Side.SERVER); this.rpcSystem = new RpcSystem<>(network, bootstrapInterface); diff --git a/runtime-rpc/src/main/java/org/capnproto/TwoPartyVatNetwork.java b/runtime-rpc/src/main/java/org/capnproto/TwoPartyVatNetwork.java index 965d812..70df26d 100644 --- a/runtime-rpc/src/main/java/org/capnproto/TwoPartyVatNetwork.java +++ b/runtime-rpc/src/main/java/org/capnproto/TwoPartyVatNetwork.java @@ -1,6 +1,5 @@ package org.capnproto; -import java.nio.channels.AsynchronousByteChannel; import java.nio.channels.AsynchronousSocketChannel; import java.util.List; import java.util.concurrent.CompletableFuture; @@ -11,12 +10,12 @@ public class TwoPartyVatNetwork private CompletableFuture previousWrite = CompletableFuture.completedFuture(null); private final CompletableFuture disconnectPromise = new CompletableFuture<>(); - private final AsynchronousByteChannel channel; + private final AsynchronousSocketChannel channel; private final RpcTwoPartyProtocol.Side side; private final MessageBuilder peerVatId = new MessageBuilder(4); private boolean accepted; - public TwoPartyVatNetwork(AsynchronousByteChannel channel, RpcTwoPartyProtocol.Side side) { + public TwoPartyVatNetwork(AsynchronousSocketChannel channel, RpcTwoPartyProtocol.Side side) { this.channel = channel; this.side = side; this.peerVatId.initRoot(RpcTwoPartyProtocol.VatId.factory).setSide( diff --git a/runtime/src/main/java/org/capnproto/Serialize.java b/runtime/src/main/java/org/capnproto/Serialize.java index 7bdc158..8b51963 100644 --- a/runtime/src/main/java/org/capnproto/Serialize.java +++ b/runtime/src/main/java/org/capnproto/Serialize.java @@ -24,14 +24,10 @@ package org.capnproto; import java.io.IOException; import java.nio.ByteBuffer; import java.nio.ByteOrder; -import java.nio.channels.AsynchronousByteChannel; -import java.nio.channels.CompletionHandler; -import java.nio.channels.ReadableByteChannel; -import java.nio.channels.WritableByteChannel; +import java.nio.channels.*; import java.util.ArrayList; -import java.util.Arrays; import java.util.concurrent.CompletableFuture; -import java.util.concurrent.CompletionStage; +import java.util.concurrent.TimeUnit; import java.util.function.Consumer; public final class Serialize { @@ -208,14 +204,11 @@ public final class Serialize { } } - static final class AsyncMessageReader { - - private final AsynchronousByteChannel channel; + static abstract class AsyncMessageReader { private final ReaderOptions options; - private final CompletableFuture readCompleted = new CompletableFuture<>(); + protected final CompletableFuture readCompleted = new CompletableFuture<>(); - public AsyncMessageReader(AsynchronousByteChannel channel, ReaderOptions options) { - this.channel = channel; + AsyncMessageReader(ReaderOptions options) { this.options = options; } @@ -226,8 +219,8 @@ public final class Serialize { private void readHeader() { read(Constants.BYTES_PER_WORD, firstWord -> { - final var segmentCount = 1 + firstWord.getInt(0); - final var segment0Size = firstWord.getInt(4); + var segmentCount = 1 + firstWord.getInt(0); + var segment0Size = firstWord.getInt(4); if (segmentCount == 1) { readSegments(segment0Size, segmentCount, segment0Size, null); @@ -241,7 +234,7 @@ public final class Serialize { } read(4 * (segmentCount & ~1), moreSizesRaw -> { - final var moreSizes = new int[segmentCount - 1]; + var moreSizes = new int[segmentCount - 1]; var totalWords = segment0Size; for (int ii = 0; ii < segmentCount - 1; ++ii) { @@ -262,7 +255,7 @@ public final class Serialize { return; } - final var segmentSlices = new ByteBuffer[segmentCount]; + var segmentSlices = new ByteBuffer[segmentCount]; if (totalWords == 0) { for (int ii = 0; ii < segmentCount; ++ii) { segmentSlices[ii] = ByteBuffer.allocate(0); @@ -273,17 +266,19 @@ public final class Serialize { read(totalWords * Constants.BYTES_PER_WORD, allSegments -> { allSegments.rewind(); - segmentSlices[0] = allSegments.slice(); - segmentSlices[0].limit(segment0Size * Constants.BYTES_PER_WORD); - segmentSlices[0].order(ByteOrder.LITTLE_ENDIAN); + var segment0 = allSegments.slice(); + segment0.limit(segment0Size * Constants.BYTES_PER_WORD); + segment0.order(ByteOrder.LITTLE_ENDIAN); + segmentSlices[0] = segment0; int offset = segment0Size; for (int ii = 1; ii < segmentCount; ++ii) { allSegments.position(offset * Constants.BYTES_PER_WORD); var segmentSize = moreSizes[ii-1]; - segmentSlices[ii] = allSegments.slice(); - segmentSlices[ii].limit(segmentSize * Constants.BYTES_PER_WORD); - segmentSlices[ii].order(ByteOrder.LITTLE_ENDIAN); + var segment = allSegments.slice(); + segment.limit(segmentSize * Constants.BYTES_PER_WORD); + segment.order(ByteOrder.LITTLE_ENDIAN); + segmentSlices[ii] = segment; offset += segmentSize; } @@ -291,19 +286,71 @@ public final class Serialize { }); } - private void read(int bytes, Consumer consumer) { - final var buffer = Serialize.makeByteBuffer(bytes); - final var handler = new CompletionHandler() { + abstract void read(int bytes, Consumer consumer); + } + + static class AsyncSocketReader extends AsyncMessageReader { + private final AsynchronousSocketChannel channel; + private final long timeout; + private final TimeUnit timeUnit; + + AsyncSocketReader(AsynchronousSocketChannel channel, ReaderOptions options, long timeout, TimeUnit timeUnit) { + super(options); + this.channel = channel; + this.timeout = timeout; + this.timeUnit = timeUnit; + } + + void read(int bytes, Consumer consumer) { + var buffer = Serialize.makeByteBuffer(bytes); + var handler = new CompletionHandler() { @Override public void completed(Integer result, Object attachment) { - // System.out.println("read " + result + " bytes"); + //System.out.println(channel.toString() + ": read " + result + " bytes"); if (result <= 0) { var text = result == 0 ? "Read zero bytes. Is the channel in non-blocking mode?" : "Premature EOF"; readCompleted.completeExceptionally(new IOException(text)); } else if (buffer.hasRemaining()) { - // retry + // partial read + channel.read(buffer, timeout, timeUnit, null, this); + } else { + consumer.accept(buffer); + } + } + + @Override + public void failed(Throwable exc, Object attachment) { + readCompleted.completeExceptionally(exc); + } + }; + + this.channel.read(buffer, this.timeout, this.timeUnit, null, handler); + } + } + + static class AsyncByteChannelReader extends AsyncMessageReader { + private final AsynchronousByteChannel channel; + + AsyncByteChannelReader(AsynchronousByteChannel channel, ReaderOptions options) { + super(options); + this.channel = channel; + } + + void read(int bytes, Consumer consumer) { + var buffer = Serialize.makeByteBuffer(bytes); + var handler = new CompletionHandler() { + @Override + public void completed(Integer result, Object attachment) { + //System.out.println(channel.toString() + ": read " + result + " bytes"); + if (result <= 0) { + var text = result == 0 + ? "Read zero bytes. Is the channel in non-blocking mode?" + : "Premature EOF"; + readCompleted.completeExceptionally(new IOException(text)); + } else if (buffer.hasRemaining()) { + // partial read channel.read(buffer, null, this); } else { consumer.accept(buffer); @@ -325,39 +372,51 @@ public final class Serialize { } public static CompletableFuture readAsync(AsynchronousByteChannel channel, ReaderOptions options) { - return new AsyncMessageReader(channel, options).getMessage(); + return new AsyncByteChannelReader(channel, options).getMessage(); + } + + public static CompletableFuture readAsync(AsynchronousSocketChannel channel) { + return readAsync(channel, ReaderOptions.DEFAULT_READER_OPTIONS, Long.MAX_VALUE, TimeUnit.SECONDS); + } + + public static CompletableFuture readAsync(AsynchronousSocketChannel channel, ReaderOptions options) { + return readAsync(channel, options, Long.MAX_VALUE, TimeUnit.SECONDS); + } + + public static CompletableFuture readAsync(AsynchronousSocketChannel channel, long timeout, TimeUnit timeUnit) { + return readAsync(channel, ReaderOptions.DEFAULT_READER_OPTIONS, timeout, timeUnit); + } + + public static CompletableFuture readAsync(AsynchronousSocketChannel channel, ReaderOptions options, long timeout, TimeUnit timeUnit) { + return new AsyncSocketReader(channel, options, timeout, timeUnit).getMessage(); } public static CompletableFuture writeAsync(AsynchronousByteChannel outputChannel, MessageBuilder message) { - final var writeCompleted = new CompletableFuture(); - final var segments = message.getSegmentsForOutput(); - assert segments.length > 0: "Empty message"; - final int tableSize = (segments.length + 2) & (~1); - final var table = ByteBuffer.allocate(4 * tableSize); - - table.order(ByteOrder.LITTLE_ENDIAN); - table.putInt(0, segments.length - 1); - - for (int ii = 0; ii < segments.length; ++ii) { - table.putInt(4 * (ii + 1), segments[ii].limit() / 8); - } - - outputChannel.write(table, 0, new CompletionHandler<>() { + var writeCompleted = new CompletableFuture(); + var segments = message.getSegmentsForOutput(); + var header = getHeaderForOutput(segments); + outputChannel.write(header, -1, new CompletionHandler<>() { @Override - public void completed(Integer result, Integer attachment) { - //System.out.println("Wrote " + result + " bytes"); - if (writeCompleted.isCancelled()) { - // TODO do we really want to interrupt here? - return; - } + public void completed(Integer result, Integer index) { + var currentSegment = index < 0 ? header : segments[index]; - if (attachment == segments.length) { - writeCompleted.complete(null); - return; + if (result < 0) { + writeCompleted.completeExceptionally(new IOException("Write failed")); + } + else if (currentSegment.hasRemaining()) { + // partial write + outputChannel.write(currentSegment, index, this); + } + else { + index++; + if (index == segments.length) { + writeCompleted.complete(null); + } + else { + outputChannel.write(segments[index], index, this); + } } - - outputChannel.write(segments[attachment], attachment + 1, this); } @Override @@ -365,6 +424,63 @@ public final class Serialize { writeCompleted.completeExceptionally(exc); } }); - return writeCompleted.copy(); + + return writeCompleted; + } + + public static CompletableFuture writeAsync(AsynchronousSocketChannel outputChannel, MessageBuilder message) { + return writeAsync(outputChannel, message, Long.MAX_VALUE, TimeUnit.SECONDS); + } + + public static CompletableFuture writeAsync(AsynchronousSocketChannel outputChannel, MessageBuilder message, long timeout, TimeUnit timeUnit) { + var writeCompleted = new CompletableFuture(); + var segments = message.getSegmentsForOutput(); + var header = getHeaderForOutput(segments); + long totalBytes = header.remaining(); + + // TODO avoid this copy? + var allSegments = new ByteBuffer[segments.length+1]; + allSegments[0] = header; + for (int ii = 0; ii < segments.length; ++ii) { + var segment = segments[ii]; + allSegments[ii+1] = segment; + totalBytes += segment.remaining(); + } + + outputChannel.write(allSegments, 0, allSegments.length, timeout, timeUnit, totalBytes, new CompletionHandler<>() { + @Override + public void completed(Long result, Long totalBytes) { + //System.out.println(outputChannel.toString() + ": Wrote " + result + "/" + totalBytes + " bytes"); + if (result < 0) { + writeCompleted.completeExceptionally(new IOException("Write failed")); + } + else if (result < totalBytes) { + // partial write + outputChannel.write(allSegments, 0, allSegments.length, timeout, timeUnit, totalBytes - result, this); + } + else { + writeCompleted.complete(null); + } + } + + @Override + public void failed(Throwable exc, Long attachment) { + writeCompleted.completeExceptionally(exc); + } + }); + + return writeCompleted; + } + + private static ByteBuffer getHeaderForOutput(ByteBuffer[] segments) { + assert segments.length > 0: "Empty message"; + int tableSize = (segments.length + 2) & (~1); + var table = ByteBuffer.allocate(4 * tableSize); + table.order(ByteOrder.LITTLE_ENDIAN); + table.putInt(0, segments.length - 1); + for (int ii = 0; ii < segments.length; ++ii) { + table.putInt(4 * (ii + 1), segments[ii].limit() / 8); + } + return table; } } diff --git a/runtime/src/test/java/org/capnproto/SerializeTest.java b/runtime/src/test/java/org/capnproto/SerializeTest.java index aecd486..d2245bb 100644 --- a/runtime/src/test/java/org/capnproto/SerializeTest.java +++ b/runtime/src/test/java/org/capnproto/SerializeTest.java @@ -22,7 +22,6 @@ package org.capnproto; import java.io.IOException; -import java.net.SocketOptions; import java.nio.ByteBuffer; import java.nio.channels.AsynchronousServerSocketChannel; import java.nio.channels.AsynchronousSocketChannel; @@ -72,10 +71,10 @@ public class SerializeTest { } // read via AsyncChannel - expectSerializesToAsync(exampleSegmentCount, exampleBytes); + expectSerializesToAsyncSocket(exampleSegmentCount, exampleBytes); } - private void expectSerializesToAsync(int exampleSegmentCount, byte[] exampleBytes) throws IOException { + private void expectSerializesToAsyncSocket(int exampleSegmentCount, byte[] exampleBytes) throws IOException { var done = new CompletableFuture(); var server = AsynchronousServerSocketChannel.open(); server.bind(null); @@ -108,10 +107,7 @@ public class SerializeTest { checkSegmentContents(exampleSegmentCount, messageReader.arena); done.get(); } - catch (InterruptedException exc) { - Assert.fail(exc.getMessage()); - } - catch (ExecutionException exc) { + catch (InterruptedException | ExecutionException exc) { Assert.fail(exc.getMessage()); } }