diff --git a/runtime-rpc/src/test/java/org/capnproto/CapabilityTest.java b/runtime-rpc/src/test/java/org/capnproto/CapabilityTest.java index 7354a10..2296746 100644 --- a/runtime-rpc/src/test/java/org/capnproto/CapabilityTest.java +++ b/runtime-rpc/src/test/java/org/capnproto/CapabilityTest.java @@ -21,16 +21,13 @@ package org.capnproto; // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. -import org.capnproto.AnyPointer; -import org.capnproto.CallContext; -import org.capnproto.Capability; -import org.capnproto.RpcException; import org.capnproto.rpctest.Test; import org.junit.Assert; import java.util.concurrent.CompletableFuture; import java.util.concurrent.ExecutionException; +import java.util.concurrent.atomic.AtomicBoolean; class Counter { private int count = 0; @@ -291,4 +288,70 @@ public final class CapabilityTest { Assert.assertEquals(579, result.getTotalI()); Assert.assertEquals(321, result.getTotalJ()); } + + @org.junit.Test + public void testCapabilityServerSet() { + var set1 = new Capability.CapabilityServerSet(); + var set2 = new Capability.CapabilityServerSet(); + + var callCount = new Counter(); + var clientStandalone = new Test.TestInterface.Client(new RpcTestUtil.TestInterfaceImpl(callCount)); + var clientNull = new Test.TestInterface.Client(); + + var ownServer1 = new RpcTestUtil.TestInterfaceImpl(callCount); + var server1 = ownServer1; + var client1 = set1.add(Test.TestInterface.factory, ownServer1); + + var ownServer2 = new RpcTestUtil.TestInterfaceImpl(callCount); + var server2 = ownServer2; + var client2 = set2.add(Test.TestInterface.factory, ownServer2); + + // Getting the local server using the correct set works. + Assert.assertEquals(server1, set1.getLocalServer(client1).join()); + Assert.assertEquals(server2, set2.getLocalServer(client2).join()); + + // Getting the local server using the wrong set doesn't work. + Assert.assertNull(set1.getLocalServer(client2).join()); + Assert.assertNull(set2.getLocalServer(client1).join()); + Assert.assertNull(set1.getLocalServer(clientStandalone).join()); + Assert.assertNull(set1.getLocalServer(clientNull).join()); + + var promise = new CompletableFuture(); + var clientPromise = new Test.TestInterface.Client(promise); + + var errorPromise = new CompletableFuture(); + var clientErrorPromise = new Test.TestInterface.Client(errorPromise); + + var resolved1 = new AtomicBoolean(false); + var resolved2 = new AtomicBoolean(false); + var resolved3 = new AtomicBoolean(false); + + var promise1 = set1.getLocalServer(clientPromise).thenAccept(server -> { + resolved1.set(true); + Assert.assertEquals(server1, server); + }); + + var promise2 = set2.getLocalServer(clientPromise).thenAccept(server -> { + resolved2.set(true); + Assert.assertNull(server); + }); + + var promise3 = set1.getLocalServer(clientErrorPromise).whenComplete((server, exc) -> { + resolved3.set(true); + Assert.assertNull(server); + Assert.assertNotNull(exc); + Assert.assertTrue(exc.getCause() instanceof RpcException); + }); + + Assert.assertFalse(resolved1.get()); + Assert.assertFalse(resolved2.get()); + Assert.assertFalse(resolved3.get()); + + promise.complete(client1); + errorPromise.completeExceptionally(RpcException.failed("foo")); + + Assert.assertTrue(resolved1.get()); + Assert.assertTrue(resolved2.get()); + Assert.assertTrue(resolved3.get()); + } } diff --git a/runtime/src/main/java/org/capnproto/Capability.java b/runtime/src/main/java/org/capnproto/Capability.java index 14e00fa..5dd8084 100644 --- a/runtime/src/main/java/org/capnproto/Capability.java +++ b/runtime/src/main/java/org/capnproto/Capability.java @@ -136,7 +136,7 @@ public final class Capability { this(other.hook); } - public Client(Server server) { + public Client(T server) { this(makeLocalClient(server)); } @@ -157,7 +157,7 @@ public final class Capability { return this.hook; } - private static ClientHook makeLocalClient(Server server) { + private static ClientHook makeLocalClient(T server) { return server.makeLocalClient(); } } @@ -169,27 +169,27 @@ public final class Capability { private ClientHook hook; ClientHook makeLocalClient() { - return new LocalClient(); + return new LocalClient<>(); } - ClientHook makeLocalClient(CapabilityServerSetBase capServerSet) { - return new LocalClient(capServerSet); + ClientHook makeLocalClient(CapabilityServerSet capServerSet) { + return new LocalClient<>(capServerSet); } - private final class LocalClient implements ClientHook { + private final class LocalClient implements ClientHook { private CompletableFuture resolveTask; private ClientHook resolved; private boolean blocked = false; private Throwable brokenException; private final Queue blockedCalls = new ArrayDeque<>(); - private final CapabilityServerSetBase capServerSet; + private final CapabilityServerSet capServerSet; LocalClient() { this(null); } - LocalClient(CapabilityServerSetBase capServerSet) { + LocalClient(CapabilityServerSet capServerSet) { Server.this.hook = this; this.capServerSet = capServerSet; var resolveTask = shortenPath(); @@ -311,14 +311,17 @@ public final class Capability { } } - public CompletableFuture getLocalServer(CapabilityServerSetBase capServerSet) { + public CompletableFuture getLocalServer(CapabilityServerSet capServerSet) { if (this.capServerSet == capServerSet) { if (this.blocked) { - var promise = new CompletableFuture(); - this.blockedCalls.add(() -> promise.complete(Server.this)); + var promise = new CompletableFuture(); + var server = (T)Server.this; + this.blockedCalls.add(() -> promise.complete(server)); return promise; } - return CompletableFuture.completedFuture(Server.this); + + var server = (T)Server.this; + return CompletableFuture.completedFuture(server); } return null; } @@ -712,13 +715,13 @@ public final class Capability { } } - static class CapabilityServerSetBase { + public static final class CapabilityServerSet { - ClientHook addInternal(Server server) { + ClientHook addInternal(T server) { return server.makeLocalClient(this); } - CompletableFuture getLocalServerInternal(ClientHook hook) { + CompletableFuture getLocalServerInternal(ClientHook hook) { for (;;) { var next = hook.getResolved(); if (next != null) { @@ -728,8 +731,9 @@ public final class Capability { break; } } + if (hook.getBrand() == Server.BRAND) { - var promise = ((Server.LocalClient)hook).getLocalServer(this); + var promise = ((Server.LocalClient)hook).getLocalServer(this); if (promise != null) { return promise; } @@ -744,12 +748,9 @@ public final class Capability { } else { // Cap is settled, so it definitely will never resolve to a member of this set. - return CompletableFuture.completedFuture(null); + return CompletableFuture.completedFuture(null); } } - } - - public static final class CapabilityServerSet extends CapabilityServerSetBase { /** * Create a new capability Client for the given Server and also add this server to the set.