From 054213a0ac7b7c57e2cfcea2575b86c7b4ea2e2a Mon Sep 17 00:00:00 2001 From: Vaci Koblizek Date: Mon, 2 Nov 2020 21:39:58 +0000 Subject: [PATCH] add rpc bootstrap factory --- .../java/org/capnproto/BootstrapFactory.java | 8 + .../src/main/java/org/capnproto/Request.java | 5 +- .../src/main/java/org/capnproto/RpcState.java | 49 +++--- .../main/java/org/capnproto/RpcSystem.java | 104 +++++++++--- .../java/org/capnproto/TwoPartyRpcSystem.java | 12 +- .../org/capnproto/TwoPartyVatNetwork.java | 33 ++-- .../main/java/org/capnproto/VatNetwork.java | 29 ++-- .../java/org/capnproto/CapabilityTest.java | 117 ++++--------- .../test/java/org/capnproto/RpcStateTest.java | 23 ++- .../src/test/java/org/capnproto/TestUtil.java | 154 +++++++++++++++++- runtime/src/test/schema/test.capnp | 96 +++++++++++ 11 files changed, 446 insertions(+), 184 deletions(-) create mode 100644 runtime/src/main/java/org/capnproto/BootstrapFactory.java diff --git a/runtime/src/main/java/org/capnproto/BootstrapFactory.java b/runtime/src/main/java/org/capnproto/BootstrapFactory.java new file mode 100644 index 0000000..c6bd075 --- /dev/null +++ b/runtime/src/main/java/org/capnproto/BootstrapFactory.java @@ -0,0 +1,8 @@ +package org.capnproto; + +public interface BootstrapFactory { + + FromPointerReader getVatIdFactory(); + + Capability.Client createFor(VatId clientId); +} \ No newline at end of file diff --git a/runtime/src/main/java/org/capnproto/Request.java b/runtime/src/main/java/org/capnproto/Request.java index 6bde1ff..e1fb41d 100644 --- a/runtime/src/main/java/org/capnproto/Request.java +++ b/runtime/src/main/java/org/capnproto/Request.java @@ -24,7 +24,8 @@ public interface Request { var hook = new RequestHook() { @Override public RemotePromise send() { - return new RemotePromise<>(CompletableFuture.failedFuture(exc), null); + return new RemotePromise<>(CompletableFuture.failedFuture(exc), + new AnyPointer.Pipeline(PipelineHook.newBrokenPipeline(exc))); } @Override @@ -47,7 +48,7 @@ public interface Request { @Override public Request getTypelessRequest() { - return null; + return new AnyPointer.Request(message.getRoot(AnyPointer.factory), hook); } }; } diff --git a/runtime/src/main/java/org/capnproto/RpcState.java b/runtime/src/main/java/org/capnproto/RpcState.java index af704f2..580bdba 100644 --- a/runtime/src/main/java/org/capnproto/RpcState.java +++ b/runtime/src/main/java/org/capnproto/RpcState.java @@ -1,6 +1,7 @@ package org.capnproto; import java.io.IOException; +import java.lang.ref.Reference; import java.lang.ref.ReferenceQueue; import java.lang.ref.WeakReference; import java.util.*; @@ -9,7 +10,7 @@ import java.util.concurrent.CompletableFuture; import java.util.concurrent.CompletionStage; import java.util.function.Consumer; -final class RpcState { +final class RpcState { private static int messageSizeHint() { return 1 + RpcProtocol.Message.factory.structSize().total(); @@ -19,12 +20,12 @@ final class RpcState { return RpcProtocol.Exception.factory.structSize().total() + exc.getMessage().length(); } - private static int MESSAGE_TARGET_SIZE_HINT + private static final int MESSAGE_TARGET_SIZE_HINT = RpcProtocol.MessageTarget.factory.structSize().total() + RpcProtocol.PromisedAnswer.factory.structSize().total() + 16; - private static int CAP_DESCRIPTOR_SIZE_HINT + private static final int CAP_DESCRIPTOR_SIZE_HINT = RpcProtocol.CapDescriptor.factory.structSize().total() + RpcProtocol.PromisedAnswer.factory.structSize().total(); @@ -63,7 +64,7 @@ final class RpcState { } } - private static final class QuestionRef extends WeakReference { + private final class QuestionRef extends WeakReference { private final QuestionDisposer disposer; @@ -77,7 +78,7 @@ final class RpcState { } } - private final class Question { + private class Question { CompletableFuture response = new CompletableFuture<>(); int[] paramExports = new int[0]; @@ -146,8 +147,8 @@ final class RpcState { public Iterator iterator() { return this.slots.values() .stream() - .map(ref -> ref.get()) - .filter(question -> question != null) + .map(Reference::get) + .filter(Objects::nonNull) .iterator(); } @@ -160,7 +161,7 @@ final class RpcState { } } - static final class Answer { + final class Answer { final int answerId; boolean active = false; PipelineHook pipeline; @@ -228,28 +229,14 @@ final class RpcState { } } - private final ExportTable exports = new ExportTable() { + private final ExportTable exports = new ExportTable<>() { @Override Export newExportable(int id) { return new Export(id); } }; - /* - private final ExportTable questions = new ExportTable<>() { - @Override - QuestionRef newExportable(int id) { - return new QuestionRef(new Question(id)); - } - }; -*/ private final QuestionExportTable questions = new QuestionExportTable(); - /*{ - @Override - Question newExportable(int id) { - return new Question(id); - } -*/ private final ImportTable answers = new ImportTable<>() { @Override @@ -273,8 +260,8 @@ final class RpcState { }; private final Map exportsByCap = new HashMap<>(); - private final Capability.Client bootstrapInterface; - private final VatNetwork.Connection connection; + private final BootstrapFactory bootstrapFactory; + private final VatNetwork.Connection connection; private final CompletableFuture onDisconnect; private Throwable disconnected = null; private CompletableFuture messageReady = CompletableFuture.completedFuture(null); @@ -282,10 +269,10 @@ final class RpcState { private final ReferenceQueue questionRefs = new ReferenceQueue<>(); private final ReferenceQueue importRefs = new ReferenceQueue<>(); - RpcState(Capability.Client bootstrapInterface, - VatNetwork.Connection connection, + RpcState(BootstrapFactory bootstrapFactory, + VatNetwork.Connection connection, CompletableFuture onDisconnect) { - this.bootstrapInterface = bootstrapInterface; + this.bootstrapFactory = bootstrapFactory; this.connection = connection; this.onDisconnect = onDisconnect; this.messageLoop = this.doMessageLoop(); @@ -445,6 +432,7 @@ final class RpcState { private void handleMessage(IncomingRpcMessage message) throws RpcException { var reader = message.getBody().getAs(RpcProtocol.Message.factory); + //System.out.println(reader.which()); switch (reader.which()) { case UNIMPLEMENTED: handleUnimplemented(reader.getUnimplemented()); @@ -548,7 +536,8 @@ final class RpcState { var payload = ret.initResults(); var content = payload.getContent().imbue(capTable); - content.setAsCap(bootstrapInterface); + var cap = this.bootstrapFactory.createFor(connection.getPeerVatId()); + content.setAsCap(cap); var caps = capTable.getTable(); var capHook = caps.length != 0 ? caps[0] @@ -1193,7 +1182,7 @@ final class RpcState { AnyPointer.Builder getResultsBuilder(); } - static class RpcResponseImpl implements RpcResponse { + class RpcResponseImpl implements RpcResponse { private final Question question; private final IncomingRpcMessage message; private final AnyPointer.Reader results; diff --git a/runtime/src/main/java/org/capnproto/RpcSystem.java b/runtime/src/main/java/org/capnproto/RpcSystem.java index 4e31205..73cc7d3 100644 --- a/runtime/src/main/java/org/capnproto/RpcSystem.java +++ b/runtime/src/main/java/org/capnproto/RpcSystem.java @@ -4,46 +4,81 @@ import java.util.HashMap; import java.util.Map; import java.util.concurrent.CompletableFuture; -public abstract class RpcSystem { +public class RpcSystem { - final VatNetwork network; - final Capability.Client bootstrapInterface; - final Map connections = new HashMap<>(); - final CompletableFuture messageLoop; - final CompletableFuture acceptLoop; + private final VatNetwork network; + private final BootstrapFactory bootstrapFactory; + private final Map, RpcState> connections = new HashMap<>(); + private final CompletableFuture messageLoop; + private final CompletableFuture acceptLoop; - public RpcSystem(VatNetwork network, Capability.Client bootstrapInterface) { + public RpcSystem(VatNetwork network) { this.network = network; - this.bootstrapInterface = bootstrapInterface; + this.bootstrapFactory = null; + this.acceptLoop = new CompletableFuture<>(); + this.messageLoop = doMessageLoop(); + } + + public VatNetwork getNetwork() { + return this.network; + } + + public RpcSystem(VatNetwork network, + Capability.Client bootstrapInterface) { + this(network, new BootstrapFactory() { + + @Override + public FromPointerReader getVatIdFactory() { + return this.getVatIdFactory(); + } + + @Override + public Capability.Client createFor(VatId clientId) { + return bootstrapInterface; + } + }); + } + + public RpcSystem(VatNetwork network, + BootstrapFactory bootstrapFactory) { + this.network = network; + this.bootstrapFactory = bootstrapFactory; this.acceptLoop = doAcceptLoop(); this.messageLoop = doMessageLoop(); } - public CompletableFuture getMessageLoop() { - return this.messageLoop; + public Capability.Client bootstrap(VatId vatId) { + var connection = this.getNetwork().connect(vatId); + if (connection != null) { + var state = getConnectionState(connection); + var hook = state.restore(); + return new Capability.Client(hook); + } + else if (this.bootstrapFactory != null) { + return this.bootstrapFactory.createFor(vatId); + } + else { + return new Capability.Client(Capability.newBrokenCap("No bootstrap interface available")); + } } - private CompletableFuture getAcceptLoop() { - return this.acceptLoop; - } + RpcState getConnectionState(VatNetwork.Connection connection) { - public void accept(VatNetwork.Connection connection) { - getConnectionState(connection); - } - - RpcState getConnectionState(VatNetwork.Connection connection) { - - var onDisconnect = new CompletableFuture() + var onDisconnect = new CompletableFuture>() .thenAccept(lostConnection -> { this.connections.remove(lostConnection); }); return connections.computeIfAbsent(connection, key -> - new RpcState(bootstrapInterface, connection, onDisconnect)); + new RpcState(this.bootstrapFactory, connection, onDisconnect)); + } + + public void accept(VatNetwork.Connection connection) { + getConnectionState(connection); } private CompletableFuture doAcceptLoop() { - return this.network.baseAccept().thenCompose(connection -> { + return this.getNetwork().baseAccept().thenCompose(connection -> { this.accept(connection); return this.doAcceptLoop(); }); @@ -56,4 +91,29 @@ public abstract class RpcSystem { } return accept.thenCompose(x -> this.doMessageLoop()); } + + public CompletableFuture getMessageLoop() { + return this.messageLoop; + } + + private CompletableFuture getAcceptLoop() { + return this.acceptLoop; + } + + public static + RpcSystem makeRpcClient(VatNetwork network) { + return new RpcSystem<>(network); + } + + public static + RpcSystem makeRpcServer(VatNetwork network, + BootstrapFactory bootstrapFactory) { + return new RpcSystem<>(network, bootstrapFactory); + } + + public static + RpcSystem makeRpcServer(VatNetwork network, + Capability.Client bootstrapInterface) { + return new RpcSystem<>(network, bootstrapInterface); + } } diff --git a/runtime/src/main/java/org/capnproto/TwoPartyRpcSystem.java b/runtime/src/main/java/org/capnproto/TwoPartyRpcSystem.java index 1add9f9..66e7efe 100644 --- a/runtime/src/main/java/org/capnproto/TwoPartyRpcSystem.java +++ b/runtime/src/main/java/org/capnproto/TwoPartyRpcSystem.java @@ -3,18 +3,20 @@ package org.capnproto; public class TwoPartyRpcSystem extends RpcSystem { + private TwoPartyVatNetwork network; + public TwoPartyRpcSystem(TwoPartyVatNetwork network, Capability.Client bootstrapInterface) { super(network, bootstrapInterface); + this.network = network; } public TwoPartyRpcSystem(TwoPartyVatNetwork network, Capability.Server bootstrapInterface) { super(network, new Capability.Client(bootstrapInterface)); + this.network = network; } - public Capability.Client bootstrap(RpcTwoPartyProtocol.VatId.Reader vatId) { - var connection = this.network.baseConnect(vatId); - var state = getConnectionState(connection); - var hook = state.restore(); - return new Capability.Client(hook); + @Override + public VatNetwork getNetwork() { + return this.network; } } diff --git a/runtime/src/main/java/org/capnproto/TwoPartyVatNetwork.java b/runtime/src/main/java/org/capnproto/TwoPartyVatNetwork.java index 6196ade..3e48db7 100644 --- a/runtime/src/main/java/org/capnproto/TwoPartyVatNetwork.java +++ b/runtime/src/main/java/org/capnproto/TwoPartyVatNetwork.java @@ -4,9 +4,15 @@ import java.nio.channels.AsynchronousSocketChannel; import java.util.List; import java.util.concurrent.CompletableFuture; + public class TwoPartyVatNetwork implements VatNetwork, - VatNetwork.Connection { + VatNetwork.Connection { + + @Override + public CompletableFuture> baseAccept() { + return this.accept(); + } public interface MessageTap { void incoming(IncomingRpcMessage message, RpcTwoPartyProtocol.Side side); @@ -33,25 +39,22 @@ public class TwoPartyVatNetwork return side; } - public RpcTwoPartyProtocol.VatId.Reader getPeerVatId() { - return peerVatId.getRoot(RpcTwoPartyProtocol.VatId.factory).asReader(); - } - public void setTap(MessageTap tap) { this.tap = tap; } - public VatNetwork.Connection asConnection() { + public Connection asConnection() { return this; } - private Connection connect(RpcTwoPartyProtocol.VatId.Reader vatId) { + @Override + public Connection connect(RpcTwoPartyProtocol.VatId.Reader vatId) { return vatId.getSide() != side ? this.asConnection() : null; } - private CompletableFuture accept() { + public CompletableFuture> accept() { if (side == RpcTwoPartyProtocol.Side.SERVER & !accepted) { accepted = true; return CompletableFuture.completedFuture(this.asConnection()); @@ -62,6 +65,10 @@ public class TwoPartyVatNetwork } } + public RpcTwoPartyProtocol.VatId.Reader getPeerVatId() { + return this.peerVatId.getRoot(RpcTwoPartyProtocol.VatId.factory).asReader(); + } + @Override public OutgoingRpcMessage newOutgoingMessage(int firstSegmentWordSize) { return new OutgoingMessage(firstSegmentWordSize); @@ -111,16 +118,6 @@ public class TwoPartyVatNetwork }); } - @Override - public Connection baseConnect(RpcTwoPartyProtocol.VatId.Reader hostId) { - return this.connect(hostId); - } - - @Override - public CompletableFuture baseAccept() { - return this.accept(); - } - final class OutgoingMessage implements OutgoingRpcMessage { private final MessageBuilder message; diff --git a/runtime/src/main/java/org/capnproto/VatNetwork.java b/runtime/src/main/java/org/capnproto/VatNetwork.java index 3f63ad6..1286c71 100644 --- a/runtime/src/main/java/org/capnproto/VatNetwork.java +++ b/runtime/src/main/java/org/capnproto/VatNetwork.java @@ -2,18 +2,23 @@ package org.capnproto; import java.util.concurrent.CompletableFuture; -public interface VatNetwork { +public interface VatNetwork +{ + interface Connection { + default OutgoingRpcMessage newOutgoingMessage() { + return newOutgoingMessage(0); + } + OutgoingRpcMessage newOutgoingMessage(int firstSegmentWordSize); + CompletableFuture receiveIncomingMessage(); + CompletableFuture onDisconnect(); + CompletableFuture shutdown(); + VatId getPeerVatId(); + } - interface Connection { - default OutgoingRpcMessage newOutgoingMessage() { - return newOutgoingMessage(0); - } - OutgoingRpcMessage newOutgoingMessage(int firstSegmentWordSize); - CompletableFuture receiveIncomingMessage(); - CompletableFuture onDisconnect(); - CompletableFuture shutdown(); - } + CompletableFuture> baseAccept(); - Connection baseConnect(VatId hostId); - CompletableFuture baseAccept(); + //FromPointerReader getVatIdFactory(); + + Connection connect(VatId hostId); } + diff --git a/runtime/src/test/java/org/capnproto/CapabilityTest.java b/runtime/src/test/java/org/capnproto/CapabilityTest.java index 49ba7bc..3397be4 100644 --- a/runtime/src/test/java/org/capnproto/CapabilityTest.java +++ b/runtime/src/test/java/org/capnproto/CapabilityTest.java @@ -21,8 +21,10 @@ package org.capnproto; +import org.capnproto.test.Test; + import org.junit.Assert; -import org.junit.Test; +import org.junit.*; import java.util.concurrent.CompletableFuture; import java.util.concurrent.ExecutionException; @@ -33,37 +35,7 @@ class Counter { int value() { return count; } } -class TestInterfaceImpl extends org.capnproto.test.Test.TestInterface.Server { - - final Counter counter; - - TestInterfaceImpl(Counter counter) { - this.counter = counter; - } - - @Override - protected CompletableFuture foo(CallContext ctx) { - this.counter.inc(); - var params = ctx.getParams(); - var result = ctx.getResults(); - Assert.assertEquals(123, params.getI()); - Assert.assertTrue(params.getJ()); - result.setX("foo"); - return READY_NOW; - } - - @Override - protected CompletableFuture baz(CallContext context) { - this.counter.inc(); - var params = context.getParams(); - TestUtil.checkTestMessage(params.getS()); - context.releaseParams(); - Assert.assertThrows(RpcException.class, () -> context.getParams()); - return READY_NOW; - } -} - -class TestExtendsImpl extends org.capnproto.test.Test.TestExtends2.Server { +class TestExtendsImpl extends Test.TestExtends2.Server { final Counter counter; @@ -72,7 +44,7 @@ class TestExtendsImpl extends org.capnproto.test.Test.TestExtends2.Server { } @Override - protected CompletableFuture foo(CallContext context) { + protected CompletableFuture foo(CallContext context) { counter.inc(); var params = context.getParams(); var result = context.getResults(); @@ -83,7 +55,7 @@ class TestExtendsImpl extends org.capnproto.test.Test.TestExtends2.Server { } @Override - protected CompletableFuture grault(CallContext context) { + protected CompletableFuture grault(CallContext context) { counter.inc(); context.releaseParams(); TestUtil.initTestMessage(context.getResults()); @@ -91,50 +63,25 @@ class TestExtendsImpl extends org.capnproto.test.Test.TestExtends2.Server { } } -class TestPipelineImpl extends org.capnproto.test.Test.TestPipeline.Server { +class TestCallOrderImpl extends Test.TestCallOrder.Server { - final Counter counter; - - TestPipelineImpl(Counter counter) { - this.counter = counter; - } + int count = 0; @Override - protected CompletableFuture getCap(CallContext ctx) { - this.counter.inc(); - var params = ctx.getParams(); - Assert.assertEquals(234, params.getN()); - var cap = params.getInCap(); - ctx.releaseParams(); - - var request = cap.fooRequest(); - var fooParams = request.getParams(); - fooParams.setI(123); - fooParams.setJ(true); - - return request.send().thenAccept(response -> { - Assert.assertEquals("foo", response.getX().toString()); - var result = ctx.getResults(); - result.setS("bar"); - - org.capnproto.test.Test.TestExtends.Server server = new TestExtendsImpl(this.counter); - result.initOutBox().setCap(server); - }); - } - - @Override - protected CompletableFuture getAnyCap(CallContext context) { - return super.getAnyCap(context); + protected CompletableFuture getCallSequence(CallContext context) { + var result = context.getResults(); + result.setN(this.count++); + return READY_NOW; } } public class CapabilityTest { - @Test + @org.junit.Test public void testBasic() { var callCount = new Counter(); - var client = new org.capnproto.test.Test.TestInterface.Client( - new TestInterfaceImpl(callCount)); + var client = new Test.TestInterface.Client( + new TestUtil.TestInterfaceImpl(callCount)); var request1 = client.fooRequest(); request1.getParams().setI(123); @@ -155,15 +102,15 @@ public class CapabilityTest { }); } - @Test + @org.junit.Test public void testInheritance() throws ExecutionException, InterruptedException { var callCount = new Counter(); - var client1 = new org.capnproto.test.Test.TestExtends.Client( + var client1 = new Test.TestExtends.Client( new TestExtendsImpl(callCount)); - org.capnproto.test.Test.TestInterface.Client client2 = client1; - var client = (org.capnproto.test.Test.TestExtends.Client)client2; + Test.TestInterface.Client client2 = client1; + var client = (Test.TestExtends.Client)client2; var request1 = client.fooRequest(); request1.getParams().setI(321); @@ -183,26 +130,26 @@ public class CapabilityTest { Assert.assertEquals(2, callCount.value()); } - @Test + @org.junit.Test public void testPipelining() throws ExecutionException, InterruptedException { var callCount = new Counter(); var chainedCallCount = new Counter(); - var client = new org.capnproto.test.Test.TestPipeline.Client( - new TestPipelineImpl(callCount)); + var client = new Test.TestPipeline.Client( + new TestUtil.TestPipelineImpl(callCount)); var request = client.getCapRequest(); var params = request.getParams(); params.setN(234); - params.setInCap(new org.capnproto.test.Test.TestInterface.Client( - new TestInterfaceImpl(chainedCallCount))); + params.setInCap(new Test.TestInterface.Client( + new TestUtil.TestInterfaceImpl(chainedCallCount))); var promise = request.send(); var outbox = promise.getOutBox(); var pipelineRequest = outbox.getCap().fooRequest(); pipelineRequest.getParams().setI(321); var pipelinePromise = pipelineRequest.send(); - var pipelineRequest2 = new org.capnproto.test.Test.TestExtends.Client(promise.getOutBox().getCap()).graultRequest(); + var pipelineRequest2 = new Test.TestExtends.Client(promise.getOutBox().getCap()).graultRequest(); var pipelinePromise2 = pipelineRequest2.send(); // Hmm, we have no means to defer the evaluation of callInternal. The best we can do is @@ -219,7 +166,7 @@ public class CapabilityTest { Assert.assertEquals(1, chainedCallCount.value()); } - class TestThisCap extends org.capnproto.test.Test.TestInterface.Server { + class TestThisCap extends Test.TestInterface.Server { Counter counter; @@ -227,29 +174,29 @@ public class CapabilityTest { this.counter = counter; } - org.capnproto.test.Test.TestInterface.Client getSelf() { + Test.TestInterface.Client getSelf() { return this.thisCap(); } @Override - protected CompletableFuture bar(CallContext context) { + protected CompletableFuture bar(CallContext context) { this.counter.inc(); return READY_NOW; } } - @Test + @org.junit.Test public void testGenerics() { - var factory = org.capnproto.test.Test.TestGenerics.newFactory(org.capnproto.test.Test.TestAllTypes.factory, AnyPointer.factory); + var factory = Test.TestGenerics.newFactory(Test.TestAllTypes.factory, AnyPointer.factory); } - @Test + @org.junit.Test public void thisCap() { var callCount = new Counter(); var server = new TestThisCap(callCount); - var client = new org.capnproto.test.Test.TestInterface.Client(server); + var client = new Test.TestInterface.Client(server); client.barRequest().send().join(); Assert.assertEquals(1, callCount.value()); diff --git a/runtime/src/test/java/org/capnproto/RpcStateTest.java b/runtime/src/test/java/org/capnproto/RpcStateTest.java index 7f5c6d5..173bdc4 100644 --- a/runtime/src/test/java/org/capnproto/RpcStateTest.java +++ b/runtime/src/test/java/org/capnproto/RpcStateTest.java @@ -8,8 +8,6 @@ import org.junit.Test; import java.util.ArrayDeque; import java.util.Queue; import java.util.concurrent.CompletableFuture; -import java.util.concurrent.Executor; -import java.util.concurrent.Executors; public class RpcStateTest { @@ -23,7 +21,7 @@ public class RpcStateTest { } } - class TestConnection implements VatNetwork.Connection { + class TestConnection implements VatNetwork.Connection { private CompletableFuture nextIncomingMessage = new CompletableFuture<>(); private final CompletableFuture disconnect = new CompletableFuture<>(); @@ -69,6 +67,11 @@ public class RpcStateTest { this.disconnect.complete(null); return this.disconnect.copy(); } + + @Override + public RpcTwoPartyProtocol.VatId.Reader getPeerVatId() { + return null; + } } TestConnection connection; @@ -80,7 +83,19 @@ public class RpcStateTest { public void setUp() throws Exception { this.connection = new TestConnection(); this.bootstrapInterface = new Capability.Client(Capability.newNullCap()); - this.rpc = new RpcState(bootstrapInterface, connection, connection.disconnect); + var bootstrapFactory = new BootstrapFactory() { + @Override + public FromPointerReader getVatIdFactory() { + return RpcTwoPartyProtocol.VatId.factory; + } + + @Override + public Capability.Client createFor(RpcTwoPartyProtocol.VatId.Reader clientId) { + return bootstrapInterface; + } + }; + + this.rpc = new RpcState(bootstrapFactory, connection, connection.disconnect); } @After diff --git a/runtime/src/test/java/org/capnproto/TestUtil.java b/runtime/src/test/java/org/capnproto/TestUtil.java index 75f9727..cfda3f2 100644 --- a/runtime/src/test/java/org/capnproto/TestUtil.java +++ b/runtime/src/test/java/org/capnproto/TestUtil.java @@ -1,9 +1,12 @@ package org.capnproto; +import org.capnproto.test.Test; import org.junit.Assert; +import java.util.concurrent.CompletableFuture; + class TestUtil { - static void initTestMessage(org.capnproto.test.Test.TestAllTypes.Builder builder) { + static void initTestMessage(Test.TestAllTypes.Builder builder) { builder.setVoidField(Void.VOID); builder.setBoolField(true); builder.setInt8Field((byte) -123); @@ -12,26 +15,165 @@ class TestUtil { builder.setInt64Field(-123456789012345L); builder.setUInt8Field((byte) 234); builder.setUInt16Field((short) 45678); - builder.setUInt32Field((int) 3456789012l); + builder.setUInt32Field((int) 3456789012L); builder.setUInt64Field(1234567890123456789L); builder.setFloat32Field(1234.5f); builder.setFloat64Field(-123e45); builder.setTextField("foo"); } - static void checkTestMessage(org.capnproto.test.Test.TestAllTypes.Reader reader) { + static void checkTestMessage(Test.TestAllTypes.Reader reader) { Assert.assertEquals(Void.VOID, reader.getVoidField()); Assert.assertTrue(reader.getBoolField()); Assert.assertEquals((byte)-123, reader.getInt8Field()); Assert.assertEquals((short)-12345, reader.getInt16Field()); Assert.assertEquals(-12345678, reader.getInt32Field()); - Assert.assertEquals(-123456789012345l, reader.getInt64Field()); + Assert.assertEquals(-123456789012345L, reader.getInt64Field()); Assert.assertEquals((byte)234, reader.getUInt8Field()); Assert.assertEquals((short)45678, reader.getUInt16Field()); - Assert.assertEquals((int)3456789012l, reader.getUInt32Field()); - Assert.assertEquals(1234567890123456789l, reader.getUInt64Field()); + Assert.assertEquals((int) 3456789012L, reader.getUInt32Field()); + Assert.assertEquals(1234567890123456789L, reader.getUInt64Field()); Assert.assertEquals(null, 1234.5f, reader.getFloat32Field(), 0.1f); Assert.assertEquals(null, -123e45, reader.getFloat64Field(), 0.1f); Assert.assertEquals("foo", reader.getTextField().toString()); } + + static class TestInterfaceImpl extends Test.TestInterface.Server { + + final Counter counter; + + TestInterfaceImpl(Counter counter) { + this.counter = counter; + } + + @Override + protected CompletableFuture foo(CallContext ctx) { + this.counter.inc(); + var params = ctx.getParams(); + var result = ctx.getResults(); + Assert.assertEquals(123, params.getI()); + Assert.assertTrue(params.getJ()); + result.setX("foo"); + return READY_NOW; + } + + @Override + protected CompletableFuture baz(CallContext context) { + this.counter.inc(); + var params = context.getParams(); + checkTestMessage(params.getS()); + context.releaseParams(); + return READY_NOW; + } + } + + static class TestTailCallerImpl extends Test.TestTailCaller.Server { + + private final Counter count; + + public TestTailCallerImpl(Counter count) { + this.count = count; + } + + @Override + protected CompletableFuture foo(CallContext context) { + this.count.inc(); + var params = context.getParams(); + var tailRequest = params.getCallee().fooRequest(); + tailRequest.getParams().setI(params.getI()); + tailRequest.getParams().setT("from TestTailCaller"); + return context.tailCall(tailRequest); + } + + public int getCount() { + return this.count.value(); + } + } + + static class TestMoreStuffImpl extends Test.TestMoreStuff.Server { + + final Counter callCount; + final Counter handleCount; + + public TestMoreStuffImpl(Counter callCount, Counter handleCount) { + this.callCount = callCount; + this.handleCount = handleCount; + } + } + + static class TestTailCalleeImpl extends Test.TestTailCallee.Server { + + private final Counter count; + + public TestTailCalleeImpl(Counter count) { + this.count = count; + } + + @Override + protected CompletableFuture foo(CallContext context) { + this.count.inc(); + + var params = context.getParams(); + var results = context.getResults(); + + results.setI(params.getI()); + results.setT(params.getT()); + results.setC(new TestCallOrderImpl()); + return READY_NOW; + } + } + + static class TestPipelineImpl extends Test.TestPipeline.Server { + + final Counter callCount; + + TestPipelineImpl(Counter callCount) { + this.callCount = callCount; + } + + @Override + protected CompletableFuture getCap(CallContext ctx) { + this.callCount.inc(); + var params = ctx.getParams(); + Assert.assertEquals(234, params.getN()); + var cap = params.getInCap(); + ctx.releaseParams(); + + var request = cap.fooRequest(); + var fooParams = request.getParams(); + fooParams.setI(123); + fooParams.setJ(true); + + return request.send().thenAccept(response -> { + Assert.assertEquals("foo", response.getX().toString()); + var result = ctx.getResults(); + result.setS("bar"); + + Test.TestExtends.Server server = new TestExtendsImpl(this.callCount); + result.initOutBox().setCap(server); + }); + } + + @Override + protected CompletableFuture getAnyCap(CallContext context) { + this.callCount.inc(); + var params = context.getParams(); + Assert.assertEquals(234, params.getN()); + + var cap = params.getInCap(); + context.releaseParams(); + + var request = new Test.TestInterface.Client(cap).fooRequest(); + request.getParams().setI(123); + request.getParams().setJ(true); + + return request.send().thenAccept(response -> { + Assert.assertEquals("foo", response.getX().toString()); + + var result = context.getResults(); + result.setS("bar"); + result.initOutBox().setCap(new TestExtendsImpl(callCount)); + }); + } + } } diff --git a/runtime/src/test/schema/test.capnp b/runtime/src/test/schema/test.capnp index cf824c8..ea500f6 100644 --- a/runtime/src/test/schema/test.capnp +++ b/runtime/src/test/schema/test.capnp @@ -21,6 +21,32 @@ struct TestAllTypes { dataField @13 : Data; } +struct TestSturdyRef { + hostId @0 :TestSturdyRefHostId; + objectId @1 :AnyPointer; +} + +struct TestSturdyRefHostId { + host @0 :Text; +} + +struct TestSturdyRefObjectId { + tag @0 :Tag; + enum Tag { + testInterface @0; + testExtends @1; + testPipeline @2; + testTailCallee @3; + testTailCaller @4; + testMoreStuff @5; + } +} + +struct TestProvisionId {} +struct TestRecipientId {} +struct TestThirdPartyCapId {} +struct TestJoinResult {} + interface TestInterface { foo @0 (i :UInt32, j :Bool) -> (x :Text); bar @1 () -> (); @@ -48,6 +74,76 @@ interface TestPipeline { } } +interface TestCallOrder { + getCallSequence @0 (expected: UInt32) -> (n: UInt32); + # First call returns 0, next returns 1, ... + # + # The input `expected` is ignored but useful for disambiguating debug logs. +} + +interface TestTailCallee { + struct TailResult { + i @0 :UInt32; + t @1 :Text; + c @2 :TestCallOrder; + } + + foo @0 (i :Int32, t :Text) -> TailResult; +} + +interface TestTailCaller { + foo @0 (i :Int32, callee :TestTailCallee) -> TestTailCallee.TailResult; +} + +interface TestHandle {} + +interface TestMoreStuff extends(TestCallOrder) { + # Catch-all type that contains lots of testing methods. + + callFoo @0 (cap :TestInterface) -> (s: Text); + # Call `cap.foo()`, check the result, and return "bar". + + callFooWhenResolved @1 (cap :TestInterface) -> (s: Text); + # Like callFoo but waits for `cap` to resolve first. + + neverReturn @2 (cap :TestInterface) -> (capCopy :TestInterface); + # Doesn't return. You should cancel it. + + hold @3 (cap :TestInterface) -> (); + # Returns immediately but holds on to the capability. + + callHeld @4 () -> (s: Text); + # Calls the capability previously held using `hold` (and keeps holding it). + + getHeld @5 () -> (cap :TestInterface); + # Returns the capability previously held using `hold` (and keeps holding it). + + echo @6 (cap :TestCallOrder) -> (cap :TestCallOrder); + # Just returns the input cap. + + expectCancel @7 (cap :TestInterface) -> (); + # evalLater()-loops forever, holding `cap`. Must be canceled. + + methodWithDefaults @8 (a :Text, b :UInt32 = 123, c :Text = "foo") -> (d :Text, e :Text = "bar"); + + methodWithNullDefault @12 (a :Text, b :TestInterface = null); + + getHandle @9 () -> (handle :TestHandle); + # Get a new handle. Tests have an out-of-band way to check the current number of live handles, so + # this can be used to test garbage collection. + + getNull @10 () -> (nullCap :TestMoreStuff); + # Always returns a null capability. + + getEnormousString @11 () -> (str :Text); + # Attempts to return an 100MB string. Should always fail. + + writeToFd @13 (fdCap1 :TestInterface, fdCap2 :TestInterface) + -> (fdCap3 :TestInterface, secondFdPresent :Bool); + # Expects fdCap1 and fdCap2 wrap socket file descriptors. Writes "foo" to the first and "bar" to + # the second. Also creates a socketpair, writes "baz" to one end, and returns the other end. +} + struct TestGenerics(Foo, Bar) { foo @0 :Foo; rev @1 :TestGenerics(Bar, Foo);