From 69a045deec25654181f02d6cf770349caf199416 Mon Sep 17 00:00:00 2001 From: Vaci Koblizek Date: Thu, 12 Nov 2020 22:13:48 +0000 Subject: [PATCH] make requests autoclosable and cleanup disconnection --- .../src/main/java/org/capnproto/RpcState.java | 82 +++++--- .../main/java/org/capnproto/RpcSystem.java | 22 +- .../java/org/capnproto/TwoPartyClient.java | 10 +- .../org/capnproto/TwoPartyVatNetwork.java | 27 ++- .../main/java/org/capnproto/VatNetwork.java | 8 +- .../test/java/org/capnproto/RpcStateTest.java | 90 ++++---- .../src/test/java/org/capnproto/RpcTest.java | 13 +- .../test/java/org/capnproto/TwoPartyTest.java | 196 +++++++++++------- .../main/java/org/capnproto/Capability.java | 16 +- .../main/java/org/capnproto/PipelineHook.java | 7 +- 10 files changed, 276 insertions(+), 195 deletions(-) diff --git a/runtime-rpc/src/main/java/org/capnproto/RpcState.java b/runtime-rpc/src/main/java/org/capnproto/RpcState.java index 49ab7db..537328f 100644 --- a/runtime-rpc/src/main/java/org/capnproto/RpcState.java +++ b/runtime-rpc/src/main/java/org/capnproto/RpcState.java @@ -3,7 +3,6 @@ package org.capnproto; import java.io.IOException; import java.io.PrintWriter; import java.io.StringWriter; -import java.lang.ref.Reference; import java.lang.ref.ReferenceQueue; import java.lang.ref.WeakReference; import java.util.*; @@ -35,6 +34,16 @@ final class RpcState { = RpcProtocol.CapDescriptor.factory.structSize().total() + RpcProtocol.PromisedAnswer.factory.structSize().total(); + static class DisconnectInfo { + + final CompletableFuture shutdownPromise; + // Task which is working on sending an abort message and cleanly ending the connection. + + DisconnectInfo(CompletableFuture shutdownPromise) { + this.shutdownPromise = shutdownPromise; + } + } + private final class QuestionDisposer { final int id; @@ -224,7 +233,6 @@ final class RpcState { final static class Embargo { final int id; final CompletableFuture disembargo = new CompletableFuture<>(); - Embargo(int id) { this.id = id; } @@ -263,7 +271,7 @@ final class RpcState { private final Map exportsByCap = new HashMap<>(); private final BootstrapFactory bootstrapFactory; private final VatNetwork.Connection connection; - private final CompletableFuture onDisconnect; + private final CompletableFuture disconnectFulfiller; private Throwable disconnected = null; private CompletableFuture messageReady = CompletableFuture.completedFuture(null); private final CompletableFuture messageLoop = new CompletableFuture<>(); @@ -273,10 +281,10 @@ final class RpcState { RpcState(BootstrapFactory bootstrapFactory, VatNetwork.Connection connection, - CompletableFuture onDisconnect) { + CompletableFuture disconnectFulfiller) { this.bootstrapFactory = bootstrapFactory; this.connection = connection; - this.onDisconnect = onDisconnect; + this.disconnectFulfiller = disconnectFulfiller; startMessageLoop(); } @@ -284,13 +292,10 @@ final class RpcState { return this.messageLoop; } - public CompletableFuture onDisconnect() { - return this.messageLoop; - } - - CompletableFuture disconnect(Throwable exc) { + void disconnect(Throwable exc) { if (isDisconnected()) { - return CompletableFuture.failedFuture(this.disconnected); + // Already disconnected. + return; } var networkExc = RpcException.disconnected(exc.getMessage()); @@ -334,6 +339,7 @@ final class RpcState { } } + // Send an abort message, but ignore failure. try { int sizeHint = messageSizeHint() + exceptionSizeHint(exc); var message = this.connection.newOutgoingMessage(sizeHint); @@ -344,25 +350,31 @@ final class RpcState { catch (Exception ignored) { } - var onShutdown = this.connection.shutdown().handle((x, ioExc) -> { - if (ioExc == null) { - return CompletableFuture.completedFuture(null); - } + var shutdownPromise = this.connection.shutdown() + .exceptionallyCompose(ioExc -> { - // TODO IOException? assert !(ioExc instanceof IOException); if (ioExc instanceof RpcException) { var rpcExc = (RpcException)exc; + + // Don't report disconnects as an error if (rpcExc.getType() == RpcException.Type.DISCONNECTED) { return CompletableFuture.completedFuture(null); } } + return CompletableFuture.failedFuture(ioExc); }); this.disconnected = networkExc; - return onShutdown.thenCompose(x -> CompletableFuture.failedFuture(networkExc)); + this.disconnectFulfiller.complete(new DisconnectInfo(shutdownPromise)); + + for (var pipeline: pipelinesToRelease) { + if (pipeline instanceof RpcState.RpcPipeline) { + ((RpcPipeline) pipeline).redirectLater.completeExceptionally(networkExc); + } + } } final boolean isDisconnected() { @@ -389,12 +401,7 @@ final class RpcState { ClientHook restore() { var question = questions.next(); question.setAwaitingReturn(true); - - // Run the message loop until the boostrap promise is resolved. var promise = new CompletableFuture(); - var loop = CompletableFuture.anyOf( - getMessageLoop(), promise).thenCompose(x -> promise); - int sizeHint = messageSizeHint(RpcProtocol.Bootstrap.factory); var message = connection.newOutgoingMessage(sizeHint); var builder = message.getBody().initAs(RpcProtocol.Message.factory).initBootstrap(); @@ -413,6 +420,7 @@ final class RpcState { var messageReader = this.connection.receiveIncomingMessage() .thenAccept(message -> { if (message == null) { + this.disconnect(RpcException.disconnected("Peer disconnected")); this.messageLoop.complete(null); return; } @@ -423,11 +431,12 @@ final class RpcState { // or internal RpcState is bad. this.disconnect(rpcExc); } - this.cleanupImports(); - this.cleanupQuestions(); }); - messageReader.thenRunAsync(this::startMessageLoop); + messageReader.thenRunAsync(this::startMessageLoop).exceptionallyCompose(exc -> { + assert exc == null: "Exception in startMessageLoop!"; + return CompletableFuture.failedFuture(exc); + }); } private void handleMessage(IncomingRpcMessage message) throws RpcException { @@ -470,6 +479,9 @@ final class RpcState { } break; } + + this.cleanupImports(); + this.cleanupQuestions(); } void handleUnimplemented(RpcProtocol.Message.Reader message) { @@ -1427,7 +1439,6 @@ final class RpcState { this.responseSent = false; sendErrorReturn(exc); } - cleanupAnswerTable(exports); } @@ -1512,6 +1523,7 @@ final class RpcState { RpcPipeline(Question question, CompletableFuture redirectLater) { this.question = question; + assert redirectLater != null; this.redirectLater = redirectLater; } @@ -1542,6 +1554,11 @@ final class RpcState { return new PromiseClient(pipelineClient, resolutionPromise, null); }); } + + @Override + public void close() { + this.question.finish(); + } } abstract class RpcClient implements ClientHook { @@ -1787,11 +1804,11 @@ final class RpcState { this.cap = initial; this.importId = importId; eventual.whenComplete((resolution, exc) -> { - if (exc != null) { - resolve(Capability.newBrokenCap(exc)); + if (exc == null) { + resolve(resolution); } else { - resolve(resolution); + resolve(Capability.newBrokenCap(exc)); } }); } @@ -1842,6 +1859,10 @@ final class RpcState { // TODO Flow control if (resolutionType == ResolutionType.REFLECTED && receivedCall && !isDisconnected()) { + // The new capability is hosted locally, not on the remote machine. And, we had made calls + // to the promise. We need to make sure those calls echo back to us before we allow new + // calls to go directly to the local capability, so we need to set a local embargo and send + // a `Disembargo` to echo through the peer. int sizeHint = messageSizeHint(RpcProtocol.Disembargo.factory); var message = connection.newOutgoingMessage(sizeHint); var disembargo = message.getBody().initAs(RpcProtocol.Message.factory).initDisembargo(); @@ -1852,7 +1873,8 @@ final class RpcState { disembargo.getContext().setSenderLoopback(embargo.id); final ClientHook finalReplacement = replacement; - var embargoPromise = embargo.disembargo.thenApply(x -> finalReplacement); + var embargoPromise = embargo.disembargo.thenApply( + void_ -> finalReplacement); replacement = Capability.newLocalPromiseClient(embargoPromise); message.send(); } diff --git a/runtime-rpc/src/main/java/org/capnproto/RpcSystem.java b/runtime-rpc/src/main/java/org/capnproto/RpcSystem.java index 9ab49b1..e7c5b0c 100644 --- a/runtime-rpc/src/main/java/org/capnproto/RpcSystem.java +++ b/runtime-rpc/src/main/java/org/capnproto/RpcSystem.java @@ -1,5 +1,6 @@ package org.capnproto; +import java.io.IOException; import java.util.HashMap; import java.util.Map; import java.util.concurrent.CompletableFuture; @@ -68,14 +69,21 @@ public class RpcSystem { } RpcState getConnectionState(VatNetwork.Connection connection) { + var state = this.connections.get(connection); + if (state == null) { + var onDisconnect = new CompletableFuture() + .whenComplete((info, exc) -> { + this.connections.remove(connection); + try { + connection.close(); + } catch (IOException ignored) { + } + }); - var onDisconnect = new CompletableFuture>() - .thenAccept(lostConnection -> { - this.connections.remove(lostConnection); - }); - - return connections.computeIfAbsent(connection, key -> - new RpcState(this.bootstrapFactory, connection, onDisconnect)); + state = new RpcState<>(this.bootstrapFactory, connection, onDisconnect); + this.connections.put(connection, state); + } + return state; } public void accept(VatNetwork.Connection connection) { diff --git a/runtime-rpc/src/main/java/org/capnproto/TwoPartyClient.java b/runtime-rpc/src/main/java/org/capnproto/TwoPartyClient.java index 19eb970..d6056c6 100644 --- a/runtime-rpc/src/main/java/org/capnproto/TwoPartyClient.java +++ b/runtime-rpc/src/main/java/org/capnproto/TwoPartyClient.java @@ -20,7 +20,7 @@ public class TwoPartyClient { Capability.Client bootstrapInterface, RpcTwoPartyProtocol.Side side) { this.network = new TwoPartyVatNetwork(channel, side); - this.rpcSystem = new RpcSystem(network, bootstrapInterface); + this.rpcSystem = new RpcSystem<>(network, bootstrapInterface); } public Capability.Client bootstrap() { @@ -31,12 +31,4 @@ public class TwoPartyClient { : RpcTwoPartyProtocol.Side.CLIENT); return rpcSystem.bootstrap(vatId.asReader()); } - - public TwoPartyVatNetwork getNetwork() { - return this.network; - } - - public CompletableFuture onDisconnect() { - return this.network.onDisconnect(); - } } diff --git a/runtime-rpc/src/main/java/org/capnproto/TwoPartyVatNetwork.java b/runtime-rpc/src/main/java/org/capnproto/TwoPartyVatNetwork.java index 1576aa6..7929286 100644 --- a/runtime-rpc/src/main/java/org/capnproto/TwoPartyVatNetwork.java +++ b/runtime-rpc/src/main/java/org/capnproto/TwoPartyVatNetwork.java @@ -1,5 +1,6 @@ package org.capnproto; +import java.io.IOException; import java.nio.channels.AsynchronousSocketChannel; import java.util.List; import java.util.concurrent.CompletableFuture; @@ -18,7 +19,7 @@ public class TwoPartyVatNetwork } private CompletableFuture previousWrite = CompletableFuture.completedFuture(null); - private final CompletableFuture peerDisconnected = new CompletableFuture<>(); + private final CompletableFuture disconnectPromise = new CompletableFuture<>(); private final AsynchronousSocketChannel channel; private final RpcTwoPartyProtocol.Side side; private final MessageBuilder peerVatId = new MessageBuilder(4); @@ -34,6 +35,12 @@ public class TwoPartyVatNetwork : RpcTwoPartyProtocol.Side.CLIENT); } + @Override + public void close() throws IOException { + this.channel.close(); + this.disconnectPromise.complete(null); + } + public RpcTwoPartyProtocol.Side getSide() { return side; } @@ -46,6 +53,10 @@ public class TwoPartyVatNetwork return this; } + public CompletableFuture onDisconnect() { + return this.disconnectPromise.copy(); + } + @Override public Connection connect(RpcTwoPartyProtocol.VatId.Reader vatId) { return vatId.getSide() != side @@ -59,7 +70,7 @@ public class TwoPartyVatNetwork return CompletableFuture.completedFuture(this.asConnection()); } else { - // never /home/vaci/g/capnproto-java/compilercompletes + // never completes return new CompletableFuture<>(); } } @@ -97,20 +108,20 @@ public class TwoPartyVatNetwork return message; } - @Override - public CompletableFuture onDisconnect() { - return this.peerDisconnected.copy(); - } - @Override public CompletableFuture shutdown() { - return this.previousWrite.whenComplete((x, exc) -> { + assert this.previousWrite != null: "Already shut down"; + + var result = this.previousWrite.thenRun(() -> { try { this.channel.shutdownOutput(); } catch (Exception ioExc) { } }); + + this.previousWrite = null; + return result; } final class OutgoingMessage implements OutgoingRpcMessage { diff --git a/runtime-rpc/src/main/java/org/capnproto/VatNetwork.java b/runtime-rpc/src/main/java/org/capnproto/VatNetwork.java index 1286c71..5660d6c 100644 --- a/runtime-rpc/src/main/java/org/capnproto/VatNetwork.java +++ b/runtime-rpc/src/main/java/org/capnproto/VatNetwork.java @@ -1,24 +1,22 @@ package org.capnproto; +import java.io.IOException; import java.util.concurrent.CompletableFuture; public interface VatNetwork { - interface Connection { + interface Connection extends AutoCloseable { default OutgoingRpcMessage newOutgoingMessage() { return newOutgoingMessage(0); } OutgoingRpcMessage newOutgoingMessage(int firstSegmentWordSize); CompletableFuture receiveIncomingMessage(); - CompletableFuture onDisconnect(); CompletableFuture shutdown(); VatId getPeerVatId(); + void close() throws IOException; } CompletableFuture> baseAccept(); - - //FromPointerReader getVatIdFactory(); - Connection connect(VatId hostId); } diff --git a/runtime-rpc/src/test/java/org/capnproto/RpcStateTest.java b/runtime-rpc/src/test/java/org/capnproto/RpcStateTest.java index 7cb4959..6a2f475 100644 --- a/runtime-rpc/src/test/java/org/capnproto/RpcStateTest.java +++ b/runtime-rpc/src/test/java/org/capnproto/RpcStateTest.java @@ -5,30 +5,17 @@ import org.junit.Assert; import org.junit.Before; import org.junit.Test; +import java.io.IOException; import java.util.ArrayDeque; import java.util.Queue; import java.util.concurrent.CompletableFuture; public class RpcStateTest { - class TestMessage implements IncomingRpcMessage { - - MessageBuilder builder = new MessageBuilder(); - - @Override - public AnyPointer.Reader getBody() { - return builder.getRoot(AnyPointer.factory).asReader(); - } - } - class TestConnection implements VatNetwork.Connection { private CompletableFuture nextIncomingMessage = new CompletableFuture<>(); - private final CompletableFuture disconnect = new CompletableFuture<>(); - - public void setNextIncomingMessage(IncomingRpcMessage message) { - this.nextIncomingMessage.complete(message); - } + private final CompletableFuture disconnect = new CompletableFuture<>(); @Override public OutgoingRpcMessage newOutgoingMessage(int firstSegmentWordSize) { @@ -43,6 +30,19 @@ public class RpcStateTest { @Override public void send() { sent.add(this); + var msg = new IncomingRpcMessage() { + @Override + public AnyPointer.Reader getBody() { + return message.getRoot(AnyPointer.factory).asReader(); + } + }; + + if (nextIncomingMessage.isDone()) { + nextIncomingMessage = CompletableFuture.completedFuture(msg); + } + else { + nextIncomingMessage.complete(msg); + } } @Override @@ -54,24 +54,23 @@ public class RpcStateTest { @Override public CompletableFuture receiveIncomingMessage() { - return this.nextIncomingMessage; - } - - @Override - public CompletableFuture onDisconnect() { - return this.disconnect.copy(); + return this.nextIncomingMessage; } @Override public CompletableFuture shutdown() { this.disconnect.complete(null); - return this.disconnect.copy(); + return this.disconnect.thenRun(() -> {}); } @Override public RpcTwoPartyProtocol.VatId.Reader getPeerVatId() { return null; } + + @Override + public void close() { + } } TestConnection connection; @@ -80,7 +79,7 @@ public class RpcStateTest { final Queue sent = new ArrayDeque<>(); @Before - public void setUp() throws Exception { + public void setUp() { this.connection = new TestConnection(); this.bootstrapInterface = new Capability.Client(Capability.newNullCap()); var bootstrapFactory = new BootstrapFactory() { @@ -95,45 +94,50 @@ public class RpcStateTest { } }; - this.rpc = new RpcState(bootstrapFactory, connection, connection.disconnect); + this.rpc = new RpcState<>(bootstrapFactory, connection, connection.disconnect); } @After - public void tearDown() throws Exception { + public void tearDown() { this.connection = null; this.rpc = null; this.sent.clear(); } - +/* @Test - public void handleUnimplemented() throws RpcException { - var msg = new TestMessage(); - msg.builder.getRoot(RpcProtocol.Message.factory).initUnimplemented(); - this.connection.setNextIncomingMessage(msg); + public void handleUnimplemented() { + var msg = this.connection.newOutgoingMessage(0); + var root = msg.getBody().initAs(RpcProtocol.Message.factory).initUnimplemented(); + var resolve = root.initResolve(); + RpcState.FromException(new Exception("foo"), resolve.initException()); + msg.send(); + Assert.assertFalse(sent.isEmpty()); } - +*/ @Test public void handleAbort() { - var msg = new TestMessage(); - var builder = msg.builder.getRoot(RpcProtocol.Message.factory); + var msg = this.connection.newOutgoingMessage(0); + var builder = msg.getBody().initAs(RpcProtocol.Message.factory); RpcState.FromException(RpcException.failed("Test abort"), builder.initAbort()); - this.connection.setNextIncomingMessage(msg); - //Assert.assertThrows(RpcException.class, () -> rpc.handleMessage(msg)); + msg.send(); } @Test - public void handleBootstrap() throws RpcException { - var msg = new TestMessage(); - var bootstrap = msg.builder.getRoot(RpcProtocol.Message.factory).initBootstrap(); + public void handleBootstrap() { + var msg = this.connection.newOutgoingMessage(0); + var bootstrap = msg.getBody().initAs(RpcProtocol.Message.factory).initBootstrap(); bootstrap.setQuestionId(0); - this.connection.setNextIncomingMessage(msg); - Assert.assertFalse(sent.isEmpty()); - var reply = sent.remove(); + msg.send(); + Assert.assertEquals(2, sent.size()); + + sent.remove(); // bootstrap + var reply = sent.remove(); // return + var rpcMsg = reply.getBody().getAs(RpcProtocol.Message.factory); - Assert.assertEquals(rpcMsg.which(), RpcProtocol.Message.Which.RETURN); + Assert.assertEquals(RpcProtocol.Message.Which.RETURN, rpcMsg.which()); var ret = rpcMsg.getReturn(); Assert.assertEquals(ret.getAnswerId(), 0); - Assert.assertEquals(ret.which(), RpcProtocol.Return.Which.RESULTS); + Assert.assertEquals(RpcProtocol.Return.Which.RESULTS, ret.which()); var results = ret.getResults(); Assert.assertEquals(results.getCapTable().size(), 1); // got a capability! Assert.assertTrue(results.hasContent()); diff --git a/runtime-rpc/src/test/java/org/capnproto/RpcTest.java b/runtime-rpc/src/test/java/org/capnproto/RpcTest.java index f0a3ae4..89b918d 100644 --- a/runtime-rpc/src/test/java/org/capnproto/RpcTest.java +++ b/runtime-rpc/src/test/java/org/capnproto/RpcTest.java @@ -25,13 +25,10 @@ import org.capnproto.rpctest.Test; import org.junit.Assert; -import java.lang.ref.ReferenceQueue; -import java.lang.ref.WeakReference; import java.util.ArrayDeque; import java.util.HashMap; import java.util.Map; import java.util.Queue; -import java.util.concurrent.CancellationException; import java.util.concurrent.CompletableFuture; import java.util.concurrent.CompletionException; import java.util.concurrent.ExecutionException; @@ -156,11 +153,6 @@ public class RpcTest { } } - @Override - public CompletableFuture onDisconnect() { - return null; - } - @Override public CompletableFuture shutdown() { if (this.partner == null) { @@ -174,6 +166,10 @@ public class RpcTest { public Test.TestSturdyRef.Reader getPeerVatId() { return this.peerId; } + + @Override + public void close() { + } } final TestNetwork network; @@ -430,6 +426,7 @@ public class RpcTest { Assert.assertEquals(456, response.getI()); var dependentCall1 = promise.getC().getCallSequenceRequest().send(); + Assert.assertEquals(0, dependentCall0.join().getN()); Assert.assertEquals(1, dependentCall1.join().getN()); diff --git a/runtime-rpc/src/test/java/org/capnproto/TwoPartyTest.java b/runtime-rpc/src/test/java/org/capnproto/TwoPartyTest.java index 222007c..0102c0a 100644 --- a/runtime-rpc/src/test/java/org/capnproto/TwoPartyTest.java +++ b/runtime-rpc/src/test/java/org/capnproto/TwoPartyTest.java @@ -1,60 +1,15 @@ package org.capnproto; -/* -import org.capnproto.demo.Demo; +import org.capnproto.rpctest.*; import org.junit.After; import org.junit.Assert; import org.junit.Before; -import org.junit.Test; +import org.junit.function.ThrowingRunnable; import java.io.IOException; import java.nio.channels.AsynchronousServerSocketChannel; import java.nio.channels.AsynchronousSocketChannel; -import java.util.concurrent.CompletableFuture; import java.util.concurrent.ExecutionException; -import java.util.concurrent.TimeoutException; - -class TestCap0Impl extends Demo.TestCap0.Server { - - final Demo.TestCap1.Client testCap1a = new Demo.TestCap1.Client(new TestCap1Impl()); - final Demo.TestCap1.Client testCap1b = new Demo.TestCap1.Client(new TestCap1Impl()); - - public CompletableFuture testMethod0(CallContext ctx) { - var params = ctx.getParams(); - var results = ctx.getResults(); - results.setResult0(params.getParam0()); - ctx.releaseParams(); - return CompletableFuture.completedFuture(null); - } - - public CompletableFuture testMethod1(CallContext ctx) { - var params = ctx.getParams(); - var results = ctx.getResults(); - var res0 = results.getResult0(); - res0.setAs(Demo.TestCap1.factory, testCap1a); - var res1 = results.getResult1(); - res1.setAs(Demo.TestCap1.factory, testCap1b); - var res2 = results.getResult2(); - res2.setAs(Demo.TestCap1.factory, testCap1b); - return CompletableFuture.completedFuture(null); - } -} - -class TestCap1Impl extends Demo.TestCap1.Server { -} - -class Tap implements org.capnproto.TwoPartyVatNetwork.MessageTap { - - final RpcDumper dumper = new RpcDumper(); - - @Override - public void incoming(IncomingRpcMessage message, RpcTwoPartyProtocol.Side side) { - var text = this.dumper.dump(message.getBody().getAs(RpcProtocol.Message.factory), side); - if (text.length() > 0) { - System.out.println(text); - } - } -} public class TwoPartyTest { @@ -73,7 +28,8 @@ public class TwoPartyTest { return thread; } - AsynchronousServerSocketChannel serverSocket; + AsynchronousServerSocketChannel serverAcceptSocket; + AsynchronousSocketChannel serverSocket; AsynchronousSocketChannel clientSocket; TwoPartyClient client; org.capnproto.TwoPartyVatNetwork serverNetwork; @@ -81,17 +37,17 @@ public class TwoPartyTest { @Before public void setUp() throws Exception { - this.serverSocket = AsynchronousServerSocketChannel.open(); - this.serverSocket.bind(null); + this.serverAcceptSocket = AsynchronousServerSocketChannel.open(); + this.serverAcceptSocket.bind(null); this.clientSocket = AsynchronousSocketChannel.open(); - this.clientSocket.connect(this.serverSocket.getLocalAddress()).get(); + this.clientSocket.connect(this.serverAcceptSocket.getLocalAddress()).get(); this.client = new TwoPartyClient(clientSocket); - this.client.getNetwork().setTap(new Tap()); + //this.client.getNetwork().setTap(new Tap()); - var socket = serverSocket.accept().get(); - this.serverNetwork = new org.capnproto.TwoPartyVatNetwork(socket, RpcTwoPartyProtocol.Side.SERVER); - this.serverNetwork.setTap(new Tap()); + this.serverSocket = serverAcceptSocket.accept().get(); + this.serverNetwork = new org.capnproto.TwoPartyVatNetwork(this.serverSocket, RpcTwoPartyProtocol.Side.SERVER); + //this.serverNetwork.setTap(new Tap()); //this.serverNetwork.dumper.addSchema(Demo.TestCap1); this.serverThread = runServer(this.serverNetwork); } @@ -100,36 +56,128 @@ public class TwoPartyTest { public void tearDown() throws Exception { this.clientSocket.close(); this.serverSocket.close(); + this.serverAcceptSocket.close(); this.serverThread.join(); this.client = null; } - @Test + @org.junit.Test public void testNullCap() throws ExecutionException, InterruptedException { var server = new RpcSystem<>(this.serverNetwork, new Capability.Client()); var cap = this.client.bootstrap(); - var resolved = cap.whenResolved().toCompletableFuture(); + var resolved = cap.whenResolved(); resolved.get(); } - @Test - public void testBasic() throws ExecutionException, InterruptedException, IOException { - var server = new RpcSystem<>(this.serverNetwork, new TestCap0Impl()); + @org.junit.Test + public void testBasic() throws InterruptedException, IOException { - var demo = new Demo.TestCap0.Client(this.client.bootstrap()); - var request = demo.testMethod0Request(); - var params = request.getParams(); - params.setParam0(4321); - var response = request.send(); - response.get(); - Assert.assertTrue(response.isDone()); - var results = response.get(); - Assert.assertEquals(params.getParam0(), results.getResult0()); + var callCount = new Counter(); + var server = new RpcSystem<>(this.serverNetwork, new RpcTestUtil.TestInterfaceImpl(callCount)); + + var client = new Test.TestInterface.Client(this.client.bootstrap()); + var request1 = client.fooRequest(); + request1.getParams().setI(123); + request1.getParams().setJ(true); + + var promise1 = request1.send(); + + var request2 = client.bazRequest(); + RpcTestUtil.initTestMessage(request2.getParams().initS()); + var promise2 = request2.send(); + + boolean barFailed = false; + var request3 = client.barRequest(); + var promise3 = request3.send() + .thenAccept(results -> Assert.fail("Expected bar() to fail")) + .exceptionally(exc -> null); + + var response1 = promise1.join(); + Assert.assertEquals("foo", response1.getX().toString()); + + promise2.join(); + promise3.join(); + + Assert.assertEquals(2, callCount.value()); this.clientSocket.shutdownOutput(); serverThread.join(); } - @Test + @org.junit.Test + public void testDisconnect() throws IOException { + this.serverSocket.shutdownOutput(); + this.serverNetwork.close(); + this.serverNetwork.onDisconnect().join(); + } + + @org.junit.Test + public void testPipelining() throws IOException { + var callCount = new Counter(); + var chainedCallCount = new Counter(); + + var server = new RpcSystem<>(this.serverNetwork, new RpcTestUtil.TestPipelineImpl(callCount)); + var client = new Test.TestPipeline.Client(this.client.bootstrap()); + + { + var request = client.getCapRequest(); + request.getParams().setN(234); + request.getParams().setInCap(new RpcTestUtil.TestInterfaceImpl(chainedCallCount)); + + var promise = request.send(); + + var pipelineRequest = promise.getOutBox().getCap().fooRequest(); + pipelineRequest.getParams().setI(321); + + var pipelinePromise = pipelineRequest.send(); + + var pipelineRequest2 = new Test.TestExtends.Client(promise.getOutBox().getCap()).graultRequest(); + var pipelinePromise2 = pipelineRequest2.send(); + + promise = null; + + //Assert.assertEquals(0, chainedCallCount.value()); + + var response = pipelinePromise.join(); + Assert.assertEquals("bar", response.getX().toString()); + + var response2 = pipelinePromise2.join(); + RpcTestUtil.checkTestMessage(response2); + + Assert.assertEquals(1, chainedCallCount.value()); + } + + /* + // disconnect the server + //this.serverSocket.shutdownOutput(); + this.serverNetwork.close(); + this.serverNetwork.onDisconnect().join(); + + { + // Use the now-broken capability. + var request = client.getCapRequest(); + request.getParams().setN(234); + request.getParams().setInCap(new RpcTestUtil.TestInterfaceImpl(chainedCallCount)); + + var promise = request.send(); + + var pipelineRequest = promise.getOutBox().getCap().fooRequest(); + pipelineRequest.getParams().setI(321); + var pipelinePromise = pipelineRequest.send(); + + var pipelineRequest2 = new Test.TestExtends.Client(promise.getOutBox().getCap()).graultRequest(); + var pipelinePromise2 = pipelineRequest2.send(); + + Assert.assertThrows(Exception.class, () -> pipelinePromise.join()); + Assert.assertThrows(Exception.class, () -> pipelinePromise2.join()); + + Assert.assertEquals(3, callCount.value()); + Assert.assertEquals(1, chainedCallCount.value()); + } + + */ + } +/* + @org.junit.Test public void testBasicCleanup() throws ExecutionException, InterruptedException, TimeoutException { var server = new RpcSystem<>(this.serverNetwork, new TestCap0Impl()); var demo = new Demo.TestCap0.Client(this.client.bootstrap()); @@ -145,7 +193,7 @@ public class TwoPartyTest { demo = null; } - @Test + @org.junit.Test public void testShutdown() throws InterruptedException, IOException { var server = new RpcSystem<>(this.serverNetwork, new TestCap0Impl()); var demo = new Demo.TestCap0.Client(this.client.bootstrap()); @@ -153,7 +201,7 @@ public class TwoPartyTest { serverThread.join(); } - @Test + @org.junit.Test public void testCallThrows() throws ExecutionException, InterruptedException { var impl = new Demo.TestCap0.Server() { public CompletableFuture testMethod0(CallContext ctx) { @@ -185,7 +233,7 @@ public class TwoPartyTest { } } - @Test + @org.junit.Test public void testReturnCap() throws ExecutionException, InterruptedException { // send a capability back from the server to the client var capServer = new TestCap0Impl(); @@ -204,5 +252,5 @@ public class TwoPartyTest { var cap2 = results.getResult2(); Assert.assertFalse(cap2.isNull()); } + */ } -*/ diff --git a/runtime/src/main/java/org/capnproto/Capability.java b/runtime/src/main/java/org/capnproto/Capability.java index 821960d..d05c464 100644 --- a/runtime/src/main/java/org/capnproto/Capability.java +++ b/runtime/src/main/java/org/capnproto/Capability.java @@ -150,7 +150,7 @@ public final class Capability { } private final class LocalClient implements ClientHook { - private final CompletableFuture resolveTask; + private CompletableFuture resolveTask; private ClientHook resolved; private boolean blocked = false; private final CapabilityServerSetBase capServerSet; @@ -162,11 +162,16 @@ public final class Capability { LocalClient(CapabilityServerSetBase capServerSet) { Server.this.hook = this; this.capServerSet = capServerSet; + startResolveTask(); + } - var resolver = shortenPath(); - this.resolveTask = resolver != null - ? resolver.thenAccept(client -> this.resolved = client.getHook()) - : null; + private void startResolveTask() { + var resolveTask = shortenPath(); + if (resolveTask != null) { + this.resolveTask = resolveTask.thenAccept(cap -> { + this.resolved = cap.getHook(); + }); + } } @Override @@ -209,6 +214,7 @@ public final class Capability { @Override public CompletableFuture whenMoreResolved() { if (this.resolved != null) { + System.out.println("Local client resolved! " + this.toString()); return CompletableFuture.completedFuture(this.resolved); } else if (this.resolveTask != null) { diff --git a/runtime/src/main/java/org/capnproto/PipelineHook.java b/runtime/src/main/java/org/capnproto/PipelineHook.java index 274117e..32ff23f 100644 --- a/runtime/src/main/java/org/capnproto/PipelineHook.java +++ b/runtime/src/main/java/org/capnproto/PipelineHook.java @@ -5,12 +5,7 @@ public interface PipelineHook extends AutoCloseable { ClientHook getPipelinedCap(PipelineOp[] ops); static PipelineHook newBrokenPipeline(Throwable exc) { - return new PipelineHook() { - @Override - public ClientHook getPipelinedCap(PipelineOp[] ops) { - return Capability.newBrokenCap(exc); - } - }; + return ops -> Capability.newBrokenCap(exc); } @Override