From beec84a1bcbf21259a716b6976bca846026b44dd Mon Sep 17 00:00:00 2001 From: Vaci Koblizek Date: Mon, 23 Nov 2020 20:50:11 +0000 Subject: [PATCH] calling thread drives client side loop --- .../src/main/java/org/capnproto/RpcState.java | 61 +++---- .../main/java/org/capnproto/RpcSystem.java | 19 ++- .../java/org/capnproto/TwoPartyClient.java | 8 + .../java/org/capnproto/TwoPartyServer.java | 59 ++----- .../org/capnproto/TwoPartyVatNetwork.java | 31 +--- .../test/java/org/capnproto/RpcStateTest.java | 161 ------------------ .../src/test/java/org/capnproto/RpcTest.java | 84 +++++---- .../test/java/org/capnproto/TwoPartyTest.java | 68 ++++---- 8 files changed, 158 insertions(+), 333 deletions(-) delete mode 100644 runtime-rpc/src/test/java/org/capnproto/RpcStateTest.java diff --git a/runtime-rpc/src/main/java/org/capnproto/RpcState.java b/runtime-rpc/src/main/java/org/capnproto/RpcState.java index 73b03a3..220d559 100644 --- a/runtime-rpc/src/main/java/org/capnproto/RpcState.java +++ b/runtime-rpc/src/main/java/org/capnproto/RpcState.java @@ -255,7 +255,6 @@ final class RpcState { this.bootstrapFactory = bootstrapFactory; this.connection = connection; this.disconnectFulfiller = disconnectFulfiller; - startMessageLoop(); } @Override @@ -391,36 +390,41 @@ final class RpcState { return pipeline.getPipelinedCap(new PipelineOp[0]); } - private void startMessageLoop() { + /** + * Returns a CompletableFuture that, when complete, has processed one message. + */ + public CompletableFuture pollOnce() { if (isDisconnected()) { this.messageLoop.completeExceptionally(this.disconnected); - return; + return CompletableFuture.failedFuture(this.disconnected); } - var messageReader = this.connection.receiveIncomingMessage() - .thenAccept(message -> { - if (message == null) { - this.disconnect(RpcException.disconnected("Peer disconnected")); - this.messageLoop.complete(null); - return; - } - try { - this.handleMessage(message); - - while (!this.lastEvals.isEmpty()) { - this.lastEvals.remove().call(); + return this.connection.receiveIncomingMessage() + .thenAccept(message -> { + if (message == null) { + this.disconnect(RpcException.disconnected("Peer disconnected")); + this.messageLoop.complete(null); + return; } + try { + this.handleMessage(message); + while (!this.lastEvals.isEmpty()) { + this.lastEvals.remove().call(); + } + } + catch (Throwable rpcExc) { + // either we received an Abort message from peer + // or internal RpcState is bad. + this.disconnect(rpcExc); + } + }); + } - } - catch (Throwable rpcExc) { - // either we received an Abort message from peer - // or internal RpcState is bad. - this.disconnect(rpcExc); - } - }); - - messageReader.thenRunAsync(this::startMessageLoop).exceptionallyCompose( - CompletableFuture::failedFuture); + public void runMessageLoop() { + this.pollOnce().thenRun(this::runMessageLoop).exceptionally(exc -> { + LOGGER.warning(() -> "Event loop exited: " + exc.getMessage()); + return null; + }); } private void handleMessage(IncomingRpcMessage message) throws RpcException { @@ -766,7 +770,6 @@ final class RpcState { } // This import is an unfulfilled promise. - assert !imp.promise.isDone(); switch (resolve.which()) { case CAP -> { var cap = receiveCap(resolve.getCap(), message.getAttachedFds()); @@ -981,10 +984,8 @@ final class RpcState { var resolve = message.getBody().initAs(RpcProtocol.Message.factory).initResolve(); resolve.setPromiseId(exportId); FromException(exc, resolve.initException()); - LOGGER.log(Level.INFO, this.toString() + ": > RESOLVE", exc.getMessage()); + LOGGER.info(() -> this.toString() + ": > RESOLVE FAILED export=" + exportId + " msg=" + exc.getMessage()); message.send(); - - // TODO disconnect? }); } @@ -1900,6 +1901,7 @@ final class RpcState { var replacementBrand = replacement.getBrand(); boolean isSameConnection = replacementBrand == RpcState.this; if (isSameConnection) { + // We resolved to some other RPC capability hosted by the same peer. var promise = replacement.whenMoreResolved(); if (promise != null) { var other = (PromiseClient)replacement; @@ -1936,6 +1938,7 @@ final class RpcState { // TODO Flow control if (resolutionType == ResolutionType.REFLECTED && receivedCall && !isDisconnected()) { + LOGGER.fine(() -> RpcState.this.toString() + ": embargoing reflected capability " + this.toString()); // The new capability is hosted locally, not on the remote machine. And, we had made calls // to the promise. We need to make sure those calls echo back to us before we allow new // calls to go directly to the local capability, so we need to set a local embargo and send diff --git a/runtime-rpc/src/main/java/org/capnproto/RpcSystem.java b/runtime-rpc/src/main/java/org/capnproto/RpcSystem.java index caee17e..7f7c181 100644 --- a/runtime-rpc/src/main/java/org/capnproto/RpcSystem.java +++ b/runtime-rpc/src/main/java/org/capnproto/RpcSystem.java @@ -1,14 +1,14 @@ package org.capnproto; +import java.util.HashMap; import java.util.Map; import java.util.concurrent.CompletableFuture; -import java.util.concurrent.ConcurrentHashMap; public class RpcSystem { private final VatNetwork network; private final BootstrapFactory bootstrapFactory; - private final Map, RpcState> connections = new ConcurrentHashMap<>(); + private final Map, RpcState> connections = new HashMap<>(); public RpcSystem(VatNetwork network) { this(network, clientId -> new Capability.Client( @@ -29,7 +29,6 @@ public class RpcSystem { BootstrapFactory bootstrapFactory) { this.network = network; this.bootstrapFactory = bootstrapFactory; - this.startAcceptLoop(); } public Capability.Client bootstrap(VatId vatId) { @@ -45,7 +44,8 @@ public class RpcSystem { } public void accept(VatNetwork.Connection connection) { - getConnectionState(connection); + var state = getConnectionState(connection); + state.runMessageLoop(); } private RpcState getConnectionState(VatNetwork.Connection connection) { @@ -59,10 +59,17 @@ public class RpcSystem { }); } - private void startAcceptLoop() { + public void runOnce() { + for (var state: this.connections.values()) { + state.pollOnce().join(); + return; + } + } + + public void start() { this.network.accept() .thenAccept(this::accept) - .thenRunAsync(this::startAcceptLoop); + .thenRunAsync(this::start); } public static diff --git a/runtime-rpc/src/main/java/org/capnproto/TwoPartyClient.java b/runtime-rpc/src/main/java/org/capnproto/TwoPartyClient.java index f42f5b2..d1d3159 100644 --- a/runtime-rpc/src/main/java/org/capnproto/TwoPartyClient.java +++ b/runtime-rpc/src/main/java/org/capnproto/TwoPartyClient.java @@ -1,5 +1,6 @@ package org.capnproto; +import java.io.IOException; import java.nio.channels.AsynchronousByteChannel; import java.util.concurrent.CompletableFuture; @@ -35,4 +36,11 @@ public class TwoPartyClient { CompletableFuture onDisconnect() { return this.network.onDisconnect(); } + + public CompletableFuture runUntil(CompletableFuture done) { + while (!done.isDone()) { + this.rpcSystem.runOnce(); + } + return done; + } } diff --git a/runtime-rpc/src/main/java/org/capnproto/TwoPartyServer.java b/runtime-rpc/src/main/java/org/capnproto/TwoPartyServer.java index b51b752..9bab21f 100644 --- a/runtime-rpc/src/main/java/org/capnproto/TwoPartyServer.java +++ b/runtime-rpc/src/main/java/org/capnproto/TwoPartyServer.java @@ -1,8 +1,6 @@ package org.capnproto; -import java.nio.channels.AsynchronousServerSocketChannel; -import java.nio.channels.AsynchronousSocketChannel; -import java.nio.channels.CompletionHandler; +import java.nio.channels.*; import java.util.ArrayList; import java.util.List; import java.util.concurrent.CompletableFuture; @@ -10,41 +8,18 @@ import java.util.concurrent.CompletableFuture; public class TwoPartyServer { private class AcceptedConnection { - final AsynchronousSocketChannel connection; - final TwoPartyVatNetwork network; - final RpcSystem rpcSystem; + private final AsynchronousByteChannel connection; + private final TwoPartyVatNetwork network; + private final RpcSystem rpcSystem; - AcceptedConnection(Capability.Client bootstrapInterface, AsynchronousSocketChannel connection) { + AcceptedConnection(Capability.Client bootstrapInterface, AsynchronousByteChannel connection) { this.connection = connection; this.network = new TwoPartyVatNetwork(this.connection, RpcTwoPartyProtocol.Side.SERVER); this.rpcSystem = new RpcSystem<>(network, bootstrapInterface); + this.rpcSystem.start(); } } - class ConnectionReceiver { - final AsynchronousServerSocketChannel listener; - - ConnectionReceiver(AsynchronousServerSocketChannel listener) { - this.listener = listener; - } - - CompletableFuture accept() { - CompletableFuture result = new CompletableFuture<>(); - this.listener.accept(null, new CompletionHandler<>() { - @Override - public void completed(AsynchronousSocketChannel channel, Object attachment) { - result.complete(channel); - } - - @Override - public void failed(Throwable exc, Object attachment) { - result.completeExceptionally(exc); - } - }); - return result.copy(); - } - } - private final Capability.Client bootstrapInterface; private final List connections = new ArrayList<>(); @@ -65,14 +40,20 @@ public class TwoPartyServer { } public CompletableFuture listen(AsynchronousServerSocketChannel listener) { - return this.listen(wrapListenSocket(listener)); - } + var result = new CompletableFuture(); + listener.accept(null, new CompletionHandler<>() { + @Override + public void completed(AsynchronousSocketChannel channel, Object attachment) { + accept(channel); + result.complete(null); + } - CompletableFuture listen(ConnectionReceiver listener) { - return listener.accept().thenCompose(channel -> { - this.accept(channel); - return this.listen(listener); + @Override + public void failed(Throwable exc, Object attachment) { + result.completeExceptionally(exc); + } }); + return result.thenCompose(void_ -> this.listen(listener)); } CompletableFuture drain() { @@ -82,8 +63,4 @@ public class TwoPartyServer { } return loop; } - - ConnectionReceiver wrapListenSocket(AsynchronousServerSocketChannel channel) { - return new ConnectionReceiver(channel); - } } diff --git a/runtime-rpc/src/main/java/org/capnproto/TwoPartyVatNetwork.java b/runtime-rpc/src/main/java/org/capnproto/TwoPartyVatNetwork.java index 0ae1ecd..965d812 100644 --- a/runtime-rpc/src/main/java/org/capnproto/TwoPartyVatNetwork.java +++ b/runtime-rpc/src/main/java/org/capnproto/TwoPartyVatNetwork.java @@ -9,17 +9,12 @@ public class TwoPartyVatNetwork implements VatNetwork, VatNetwork.Connection { - public interface MessageTap { - void incoming(IncomingRpcMessage message, RpcTwoPartyProtocol.Side side); - } - private CompletableFuture previousWrite = CompletableFuture.completedFuture(null); private final CompletableFuture disconnectPromise = new CompletableFuture<>(); private final AsynchronousByteChannel channel; private final RpcTwoPartyProtocol.Side side; private final MessageBuilder peerVatId = new MessageBuilder(4); private boolean accepted; - private MessageTap tap; public TwoPartyVatNetwork(AsynchronousByteChannel channel, RpcTwoPartyProtocol.Side side) { this.channel = channel; @@ -65,26 +60,9 @@ public class TwoPartyVatNetwork @Override public CompletableFuture receiveIncomingMessage() { - var message = Serialize.readAsync(channel) + return Serialize.readAsync(channel) .thenApply(reader -> (IncomingRpcMessage) new IncomingMessage(reader)) .exceptionally(exc -> null); - - // send to message tap - if (this.tap != null) { - message = message.whenComplete((msg, exc) -> { - if (this.tap == null || msg == null) { - return; - } - - var side = this.side == RpcTwoPartyProtocol.Side.CLIENT - ? RpcTwoPartyProtocol.Side.SERVER - : RpcTwoPartyProtocol.Side.CLIENT; - - this.tap.incoming(msg, side); - }); - } - - return message; } @Override @@ -109,10 +87,6 @@ public class TwoPartyVatNetwork return side; } - public void setTap(MessageTap tap) { - this.tap = tap; - } - public Connection asConnection() { return this; } @@ -120,8 +94,7 @@ public class TwoPartyVatNetwork public CompletableFuture onDisconnect() { return this.disconnectPromise.copy(); } - - + public CompletableFuture> accept() { if (side == RpcTwoPartyProtocol.Side.SERVER & !accepted) { accepted = true; diff --git a/runtime-rpc/src/test/java/org/capnproto/RpcStateTest.java b/runtime-rpc/src/test/java/org/capnproto/RpcStateTest.java deleted file mode 100644 index 258f7da..0000000 --- a/runtime-rpc/src/test/java/org/capnproto/RpcStateTest.java +++ /dev/null @@ -1,161 +0,0 @@ -package org.capnproto; - -import org.junit.After; -import org.junit.Assert; -import org.junit.Before; -import org.junit.Test; - -import java.io.IOException; -import java.util.ArrayDeque; -import java.util.Queue; -import java.util.concurrent.CompletableFuture; - -public class RpcStateTest { - - class TestConnection implements VatNetwork.Connection { - - private CompletableFuture nextIncomingMessage = new CompletableFuture<>(); - private final CompletableFuture disconnect = new CompletableFuture<>(); - - @Override - public OutgoingRpcMessage newOutgoingMessage(int firstSegmentWordSize) { - var message = new MessageBuilder(); - - return new OutgoingRpcMessage() { - @Override - public AnyPointer.Builder getBody() { - return message.getRoot(AnyPointer.factory); - } - - @Override - public void send() { - sent.add(this); - var msg = new IncomingRpcMessage() { - @Override - public AnyPointer.Reader getBody() { - return message.getRoot(AnyPointer.factory).asReader(); - } - }; - - if (nextIncomingMessage.isDone()) { - nextIncomingMessage = CompletableFuture.completedFuture(msg); - } - else { - nextIncomingMessage.complete(msg); - } - } - - @Override - public int sizeInWords() { - return 0; - } - }; - } - - @Override - public CompletableFuture receiveIncomingMessage() { - return this.nextIncomingMessage; - } - - @Override - public CompletableFuture shutdown() { - this.disconnect.complete(null); - return this.disconnect.thenRun(() -> {}); - } - - @Override - public RpcTwoPartyProtocol.VatId.Reader getPeerVatId() { - return null; - } - - @Override - public void close() { - } - } - - TestConnection connection; - Capability.Client bootstrapInterface; - RpcState rpc; - final Queue sent = new ArrayDeque<>(); - - @Before - public void setUp() { - this.connection = new TestConnection(); - this.bootstrapInterface = new Capability.Client(Capability.newNullCap()); - var bootstrapFactory = new BootstrapFactory() { - @Override - public Capability.Client createFor(RpcTwoPartyProtocol.VatId.Reader clientId) { - return bootstrapInterface; - } - }; - - this.rpc = new RpcState<>(bootstrapFactory, connection, connection.disconnect); - } - - @After - public void tearDown() { - this.connection = null; - this.rpc = null; - this.sent.clear(); - } -/* - @Test - public void handleUnimplemented() { - var msg = this.connection.newOutgoingMessage(0); - var root = msg.getBody().initAs(RpcProtocol.Message.factory).initUnimplemented(); - var resolve = root.initResolve(); - RpcState.FromException(new Exception("foo"), resolve.initException()); - msg.send(); - Assert.assertFalse(sent.isEmpty()); - } -*/ - @Test - public void handleAbort() { - var msg = this.connection.newOutgoingMessage(0); - var builder = msg.getBody().initAs(RpcProtocol.Message.factory); - RpcState.FromException(RpcException.failed("Test abort"), builder.initAbort()); - msg.send(); - } - - @Test - public void handleBootstrap() { - var msg = this.connection.newOutgoingMessage(0); - var bootstrap = msg.getBody().initAs(RpcProtocol.Message.factory).initBootstrap(); - bootstrap.setQuestionId(0); - msg.send(); - Assert.assertEquals(2, sent.size()); - - sent.remove(); // bootstrap - var reply = sent.remove(); // return - - var rpcMsg = reply.getBody().getAs(RpcProtocol.Message.factory); - Assert.assertEquals(RpcProtocol.Message.Which.RETURN, rpcMsg.which()); - var ret = rpcMsg.getReturn(); - Assert.assertEquals(ret.getAnswerId(), 0); - Assert.assertEquals(RpcProtocol.Return.Which.RESULTS, ret.which()); - var results = ret.getResults(); - Assert.assertEquals(results.getCapTable().size(), 1); // got a capability! - Assert.assertTrue(results.hasContent()); - } - - @Test - public void handleCall() { - } - - @Test - public void handleReturn() { - } - - @Test - public void handleFinish() { - } - - @Test - public void handleResolve() { - } - - @Test - public void handleDisembargo() { - } - -} \ No newline at end of file diff --git a/runtime-rpc/src/test/java/org/capnproto/RpcTest.java b/runtime-rpc/src/test/java/org/capnproto/RpcTest.java index f66a1e1..a8aa987 100644 --- a/runtime-rpc/src/test/java/org/capnproto/RpcTest.java +++ b/runtime-rpc/src/test/java/org/capnproto/RpcTest.java @@ -31,7 +31,6 @@ import java.util.Map; import java.util.Queue; import java.util.concurrent.CompletableFuture; import java.util.concurrent.CompletionException; -import java.util.concurrent.ExecutionException; import java.util.concurrent.atomic.AtomicBoolean; public class RpcTest { @@ -268,6 +267,7 @@ public class RpcTest { this.serverNetwork = this.network.add("server"); this.rpcClient = RpcSystem.makeRpcClient(this.clientNetwork); this.rpcServer = RpcSystem.makeRpcServer(this.serverNetwork, bootstrapFactory); + this.rpcServer.start(); } Capability.Client connect(Test.TestSturdyRefObjectId.Tag tag) { @@ -278,6 +278,13 @@ public class RpcTest { ref.getObjectId().initAs(Test.TestSturdyRefObjectId.factory).setTag(tag); return rpcClient.bootstrap(ref.asReader()); } + + public CompletableFuture runUntil(CompletableFuture done) { + while (!done.isDone()) { + this.rpcClient.runOnce(); + } + return done; + } } static BootstrapFactory bootstrapFactory = new BootstrapFactory<>() { @@ -321,7 +328,6 @@ public class RpcTest { this.context = null; } - @org.junit.Test public void testBasic() { var client = new Test.TestInterface.Client(context.connect(Test.TestSturdyRefObjectId.Tag.TEST_INTERFACE)); @@ -343,12 +349,15 @@ public class RpcTest { RpcTestUtil.initTestMessage(request2.getParams().initS()); var promise2 = request2.send(); - var response1 = promise1.join(); + var response1 = this.context.runUntil(promise1).join(); Assert.assertEquals("foo", response1.getX().toString()); - var response2 = promise2.join(); - promise3.join(); + while (!promise2.isDone()) { + this.context.rpcClient.runOnce(); + } + var response2 = this.context.runUntil(promise2).join(); + this.context.runUntil(promise3).join(); Assert.assertTrue(ref.barFailed); } @@ -376,10 +385,10 @@ public class RpcTest { //Assert.assertEquals(0, chainedCallCount.value()); - var response = pipelinePromise.join(); + var response = this.context.runUntil(pipelinePromise).join(); Assert.assertEquals("bar", response.getX().toString()); - var response2 = pipelinePromise2.join(); + var response2 = this.context.runUntil(pipelinePromise2).join(); RpcTestUtil.checkTestMessage(response2); Assert.assertEquals(1, chainedCallCount.value()); @@ -389,15 +398,15 @@ public class RpcTest { public void testRelease() { var client = new Test.TestMoreStuff.Client(context.connect(Test.TestSturdyRefObjectId.Tag.TEST_MORE_STUFF)); - var handle1 = client.getHandleRequest().send().join().getHandle(); + var handle1 = this.context.runUntil(client.getHandleRequest().send()).join().getHandle(); var promise = client.getHandleRequest().send(); - var handle2 = promise.join().getHandle(); + var handle2 = this.context.runUntil(promise).join().getHandle(); handle1 = null; handle2 = null; System.gc(); - client.echoRequest().send().join(); + this.context.runUntil(client.echoRequest().send()).join(); } @org.junit.Test @@ -421,15 +430,15 @@ public class RpcTest { // Make sure getCap() has been called on the server side by sending another call and waiting // for it. - Assert.assertEquals(2, client.getCallSequenceRequest().send().join().getN()); + Assert.assertEquals(2, this.context.runUntil(client.getCallSequenceRequest().send()).join().getN()); //Assert.assertEquals(3, context.restorer.callCount); // OK, now fulfill the local promise. paf.complete(new Test.TestInterface.Client(new RpcTestUtil.TestInterfaceImpl(chainedCallCount))); // We should now be able to wait for getCap() to finish. - Assert.assertEquals("bar", promise.join().getS().toString()); - Assert.assertEquals("bar", promise2.join().getS().toString()); + Assert.assertEquals("bar", this.context.runUntil(promise).join().getS().toString()); + Assert.assertEquals("bar", this.context.runUntil(promise2).join().getS().toString()); //Assert.assertEquals(3, context.restorer.callCount); Assert.assertEquals(2, chainedCallCount.value()); @@ -447,16 +456,16 @@ public class RpcTest { var promise = request.send(); var dependentCall0 = promise.getC().getCallSequenceRequest().send(); - var response = promise.join(); + var response = this.context.runUntil(promise).join(); Assert.assertEquals(456, response.getI()); var dependentCall1 = promise.getC().getCallSequenceRequest().send(); - Assert.assertEquals(0, dependentCall0.join().getN()); - Assert.assertEquals(1, dependentCall1.join().getN()); + Assert.assertEquals(0, this.context.runUntil(dependentCall0).join().getN()); + Assert.assertEquals(1, this.context.runUntil(dependentCall1).join().getN()); var dependentCall2 = response.getC().getCallSequenceRequest().send(); - Assert.assertEquals(2, dependentCall2.join().getN()); + Assert.assertEquals(2, this.context.runUntil(dependentCall2).join().getN()); Assert.assertEquals(1, calleeCallCount.value()); } @@ -482,26 +491,26 @@ public class RpcTest { var call0 = getCallSequence(pipeline, 0); var call1 = getCallSequence(pipeline, 1); - earlyCall.join(); + this.context.runUntil(earlyCall).join(); var call2 = getCallSequence(pipeline, 2); - var resolved = echo.join().getCap(); + var resolved = this.context.runUntil(echo).join().getCap(); var call3 = getCallSequence(pipeline, 3); var call4 = getCallSequence(pipeline, 4); var call5 = getCallSequence(pipeline, 5); - Assert.assertEquals(0, call0.join().getN()); - Assert.assertEquals(1, call1.join().getN()); - Assert.assertEquals(2, call2.join().getN()); - Assert.assertEquals(3, call3.join().getN()); - Assert.assertEquals(4, call4.join().getN()); - Assert.assertEquals(5, call5.join().getN()); + Assert.assertEquals(0, this.context.runUntil(call0).join().getN()); + Assert.assertEquals(1, this.context.runUntil(call1).join().getN()); + Assert.assertEquals(2, this.context.runUntil(call2).join().getN()); + Assert.assertEquals(3, this.context.runUntil(call3).join().getN()); + Assert.assertEquals(4, this.context.runUntil(call4).join().getN()); + Assert.assertEquals(5, this.context.runUntil(call5).join().getN()); } @org.junit.Test - public void testCallBrokenPromise() throws ExecutionException, InterruptedException { + public void testCallBrokenPromise() { var client = new Test.TestMoreStuff.Client(context.connect(Test.TestSturdyRefObjectId.Tag.TEST_MORE_STUFF)); var paf = new CompletableFuture(); @@ -509,7 +518,7 @@ public class RpcTest { { var req = client.holdRequest(); req.getParams().setCap(paf); - req.send().join(); + this.context.runUntil(req.send()).join(); } AtomicBoolean returned = new AtomicBoolean(false); @@ -524,10 +533,11 @@ public class RpcTest { Assert.assertFalse(returned.get()); paf.completeExceptionally(new Exception("foo")); + this.context.runUntil(req); Assert.assertTrue(returned.get()); // Verify that we are still connected - getCallSequence(client, 1).get(); + this.context.runUntil(getCallSequence(client, 1)).join(); } @org.junit.Test @@ -581,24 +591,24 @@ public class RpcTest { var call0 = getCallSequence(pipeline, 0); var call1 = getCallSequence(pipeline, 1); - earlyCall.join(); + this.context.runUntil(earlyCall).join(); var call2 = getCallSequence(pipeline, 2); - var resolved = echo.join().getCap(); + var resolved = this.context.runUntil(echo).join().getCap(); var call3 = getCallSequence(pipeline, 3); var call4 = getCallSequence(pipeline, 4); var call5 = getCallSequence(pipeline, 5); - Assert.assertEquals(0, call0.join().getN()); - Assert.assertEquals(1, call1.join().getN()); - Assert.assertEquals(2, call2.join().getN()); - Assert.assertEquals(3, call3.join().getN()); - Assert.assertEquals(4, call4.join().getN()); - Assert.assertEquals(5, call5.join().getN()); + Assert.assertEquals(0, this.context.runUntil(call0).join().getN()); + Assert.assertEquals(1, this.context.runUntil(call1).join().getN()); + Assert.assertEquals(2, this.context.runUntil(call2).join().getN()); + Assert.assertEquals(3, this.context.runUntil(call3).join().getN()); + Assert.assertEquals(4, this.context.runUntil(call4).join().getN()); + Assert.assertEquals(5, this.context.runUntil(call5).join().getN()); - int unwrappedAt = unwrap.join(); + int unwrappedAt = this.context.runUntil(unwrap).join(); Assert.assertTrue(unwrappedAt >= 0); } } diff --git a/runtime-rpc/src/test/java/org/capnproto/TwoPartyTest.java b/runtime-rpc/src/test/java/org/capnproto/TwoPartyTest.java index 03fad39..2586e22 100644 --- a/runtime-rpc/src/test/java/org/capnproto/TwoPartyTest.java +++ b/runtime-rpc/src/test/java/org/capnproto/TwoPartyTest.java @@ -7,39 +7,44 @@ import org.junit.Before; import java.io.IOException; import java.nio.channels.AsynchronousByteChannel; +import java.nio.channels.AsynchronousChannelGroup; import java.nio.channels.AsynchronousServerSocketChannel; import java.nio.channels.AsynchronousSocketChannel; import java.util.concurrent.ExecutionException; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; import java.util.function.Consumer; -@SuppressWarnings({"OverlyCoupledMethod", "OverlyLongMethod"}) public class TwoPartyTest { static final class PipeThread { Thread thread; - AsynchronousByteChannel channel; + AsynchronousSocketChannel channel; - static PipeThread newPipeThread(Consumer startFunc) throws Exception { - var pipeThread = new PipeThread(); - var serverAcceptSocket = AsynchronousServerSocketChannel.open(); - serverAcceptSocket.bind(null); - var clientSocket = AsynchronousSocketChannel.open(); + } - pipeThread.thread = new Thread(() -> { - try { - var serverSocket = serverAcceptSocket.accept().get(); - startFunc.accept(serverSocket); - } catch (InterruptedException | ExecutionException e) { - e.printStackTrace(); - } - }); - pipeThread.thread.start(); - pipeThread.thread.setName("TwoPartyTest server"); + private AsynchronousChannelGroup group; - clientSocket.connect(serverAcceptSocket.getLocalAddress()).get(); - pipeThread.channel = clientSocket; - return pipeThread; - } + PipeThread newPipeThread(Consumer startFunc) throws Exception { + var pipeThread = new PipeThread(); + var serverAcceptSocket = AsynchronousServerSocketChannel.open(this.group); + serverAcceptSocket.bind(null); + var clientSocket = AsynchronousSocketChannel.open(); + + pipeThread.thread = new Thread(() -> { + try { + var serverSocket = serverAcceptSocket.accept().get(); + startFunc.accept(serverSocket); + } catch (InterruptedException | ExecutionException exc) { + exc.printStackTrace(); + } + }); + pipeThread.thread.start(); + pipeThread.thread.setName("TwoPartyTest server"); + + clientSocket.connect(serverAcceptSocket.getLocalAddress()).get(); + pipeThread.channel = clientSocket; + return pipeThread; } PipeThread runServer(Capability.Server bootstrapInterface) throws Exception { @@ -47,19 +52,22 @@ public class TwoPartyTest { } PipeThread runServer(Capability.Client bootstrapInterface) throws Exception { - return PipeThread.newPipeThread(channel -> { + return newPipeThread(channel -> { var network = new TwoPartyVatNetwork(channel, RpcTwoPartyProtocol.Side.SERVER); var system = new RpcSystem<>(network, bootstrapInterface); + system.start(); network.onDisconnect().join(); }); } @Before - public void setUp() { + public void setUp() throws IOException { + this.group = AsynchronousChannelGroup.withThreadPool(Executors.newFixedThreadPool(5)); } @After public void tearDown() { + this.group.shutdown(); } @org.junit.Test @@ -68,7 +76,7 @@ public class TwoPartyTest { var rpcClient = new TwoPartyClient(pipe.channel); var client = rpcClient.bootstrap(); var resolved = client.whenResolved(); - resolved.get(); + rpcClient.runUntil(resolved).join(); } @org.junit.Test @@ -93,11 +101,11 @@ public class TwoPartyTest { .thenAccept(results -> Assert.fail("Expected bar() to fail")) .exceptionally(exc -> null); - var response1 = promise1.join(); + var response1 = rpcClient.runUntil(promise1).join(); Assert.assertEquals("foo", response1.getX().toString()); - promise2.join(); - promise3.join(); + rpcClient.runUntil(promise2).join(); + rpcClient.runUntil(promise3).join(); Assert.assertEquals(2, callCount.value()); } @@ -136,10 +144,10 @@ public class TwoPartyTest { //Assert.assertEquals(0, chainedCallCount.value()); - var response = pipelinePromise.join(); + var response = rpcClient.runUntil(pipelinePromise).join(); Assert.assertEquals("bar", response.getX().toString()); - var response2 = pipelinePromise2.join(); + var response2 = rpcClient.runUntil(pipelinePromise2).join(); RpcTestUtil.checkTestMessage(response2); Assert.assertEquals(1, chainedCallCount.value()); @@ -147,7 +155,7 @@ public class TwoPartyTest { // disconnect the client ((AsynchronousSocketChannel)pipe.channel).shutdownOutput(); - rpcClient.onDisconnect().join(); + rpcClient.runUntil(rpcClient.onDisconnect()).join(); { // Use the now-broken capability.