add rpc bootstrap factory

This commit is contained in:
Vaci Koblizek 2020-11-02 21:39:58 +00:00
parent 9d023f0449
commit 054213a0ac
11 changed files with 446 additions and 184 deletions

View file

@ -0,0 +1,8 @@
package org.capnproto;
public interface BootstrapFactory<VatId> {
FromPointerReader<VatId> getVatIdFactory();
Capability.Client createFor(VatId clientId);
}

View file

@ -24,7 +24,8 @@ public interface Request<Params> {
var hook = new RequestHook() {
@Override
public RemotePromise<AnyPointer.Reader> send() {
return new RemotePromise<>(CompletableFuture.failedFuture(exc), null);
return new RemotePromise<>(CompletableFuture.failedFuture(exc),
new AnyPointer.Pipeline(PipelineHook.newBrokenPipeline(exc)));
}
@Override
@ -47,7 +48,7 @@ public interface Request<Params> {
@Override
public Request<AnyPointer.Builder> getTypelessRequest() {
return null;
return new AnyPointer.Request(message.getRoot(AnyPointer.factory), hook);
}
};
}

View file

@ -1,6 +1,7 @@
package org.capnproto;
import java.io.IOException;
import java.lang.ref.Reference;
import java.lang.ref.ReferenceQueue;
import java.lang.ref.WeakReference;
import java.util.*;
@ -9,7 +10,7 @@ import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionStage;
import java.util.function.Consumer;
final class RpcState {
final class RpcState<VatId> {
private static int messageSizeHint() {
return 1 + RpcProtocol.Message.factory.structSize().total();
@ -19,12 +20,12 @@ final class RpcState {
return RpcProtocol.Exception.factory.structSize().total() + exc.getMessage().length();
}
private static int MESSAGE_TARGET_SIZE_HINT
private static final int MESSAGE_TARGET_SIZE_HINT
= RpcProtocol.MessageTarget.factory.structSize().total()
+ RpcProtocol.PromisedAnswer.factory.structSize().total()
+ 16;
private static int CAP_DESCRIPTOR_SIZE_HINT
private static final int CAP_DESCRIPTOR_SIZE_HINT
= RpcProtocol.CapDescriptor.factory.structSize().total()
+ RpcProtocol.PromisedAnswer.factory.structSize().total();
@ -63,7 +64,7 @@ final class RpcState {
}
}
private static final class QuestionRef extends WeakReference<Question> {
private final class QuestionRef extends WeakReference<Question> {
private final QuestionDisposer disposer;
@ -77,7 +78,7 @@ final class RpcState {
}
}
private final class Question {
private class Question {
CompletableFuture<RpcResponse> response = new CompletableFuture<>();
int[] paramExports = new int[0];
@ -146,8 +147,8 @@ final class RpcState {
public Iterator<Question> iterator() {
return this.slots.values()
.stream()
.map(ref -> ref.get())
.filter(question -> question != null)
.map(Reference::get)
.filter(Objects::nonNull)
.iterator();
}
@ -160,7 +161,7 @@ final class RpcState {
}
}
static final class Answer {
final class Answer {
final int answerId;
boolean active = false;
PipelineHook pipeline;
@ -228,28 +229,14 @@ final class RpcState {
}
}
private final ExportTable<Export> exports = new ExportTable<Export>() {
private final ExportTable<Export> exports = new ExportTable<>() {
@Override
Export newExportable(int id) {
return new Export(id);
}
};
/*
private final ExportTable<QuestionRef> questions = new ExportTable<>() {
@Override
QuestionRef newExportable(int id) {
return new QuestionRef(new Question(id));
}
};
*/
private final QuestionExportTable questions = new QuestionExportTable();
/*{
@Override
Question newExportable(int id) {
return new Question(id);
}
*/
private final ImportTable<Answer> answers = new ImportTable<>() {
@Override
@ -273,8 +260,8 @@ final class RpcState {
};
private final Map<ClientHook, Integer> exportsByCap = new HashMap<>();
private final Capability.Client bootstrapInterface;
private final VatNetwork.Connection connection;
private final BootstrapFactory<VatId> bootstrapFactory;
private final VatNetwork.Connection<VatId> connection;
private final CompletableFuture<java.lang.Void> onDisconnect;
private Throwable disconnected = null;
private CompletableFuture<java.lang.Void> messageReady = CompletableFuture.completedFuture(null);
@ -282,10 +269,10 @@ final class RpcState {
private final ReferenceQueue<Question> questionRefs = new ReferenceQueue<>();
private final ReferenceQueue<ImportClient> importRefs = new ReferenceQueue<>();
RpcState(Capability.Client bootstrapInterface,
VatNetwork.Connection connection,
RpcState(BootstrapFactory<VatId> bootstrapFactory,
VatNetwork.Connection<VatId> connection,
CompletableFuture<java.lang.Void> onDisconnect) {
this.bootstrapInterface = bootstrapInterface;
this.bootstrapFactory = bootstrapFactory;
this.connection = connection;
this.onDisconnect = onDisconnect;
this.messageLoop = this.doMessageLoop();
@ -445,6 +432,7 @@ final class RpcState {
private void handleMessage(IncomingRpcMessage message) throws RpcException {
var reader = message.getBody().getAs(RpcProtocol.Message.factory);
//System.out.println(reader.which());
switch (reader.which()) {
case UNIMPLEMENTED:
handleUnimplemented(reader.getUnimplemented());
@ -548,7 +536,8 @@ final class RpcState {
var payload = ret.initResults();
var content = payload.getContent().imbue(capTable);
content.setAsCap(bootstrapInterface);
var cap = this.bootstrapFactory.createFor(connection.getPeerVatId());
content.setAsCap(cap);
var caps = capTable.getTable();
var capHook = caps.length != 0
? caps[0]
@ -1193,7 +1182,7 @@ final class RpcState {
AnyPointer.Builder getResultsBuilder();
}
static class RpcResponseImpl implements RpcResponse {
class RpcResponseImpl implements RpcResponse {
private final Question question;
private final IncomingRpcMessage message;
private final AnyPointer.Reader results;

View file

@ -4,46 +4,81 @@ import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.CompletableFuture;
public abstract class RpcSystem<VatId> {
public class RpcSystem<VatId extends StructReader> {
final VatNetwork<VatId> network;
final Capability.Client bootstrapInterface;
final Map<VatNetwork.Connection, RpcState> connections = new HashMap<>();
final CompletableFuture<java.lang.Void> messageLoop;
final CompletableFuture<java.lang.Void> acceptLoop;
private final VatNetwork<VatId> network;
private final BootstrapFactory<VatId> bootstrapFactory;
private final Map<VatNetwork.Connection<VatId>, RpcState<VatId>> connections = new HashMap<>();
private final CompletableFuture<java.lang.Void> messageLoop;
private final CompletableFuture<java.lang.Void> acceptLoop;
public RpcSystem(VatNetwork<VatId> network, Capability.Client bootstrapInterface) {
public RpcSystem(VatNetwork<VatId> network) {
this.network = network;
this.bootstrapInterface = bootstrapInterface;
this.bootstrapFactory = null;
this.acceptLoop = new CompletableFuture<>();
this.messageLoop = doMessageLoop();
}
public VatNetwork<VatId> getNetwork() {
return this.network;
}
public RpcSystem(VatNetwork<VatId> network,
Capability.Client bootstrapInterface) {
this(network, new BootstrapFactory<VatId>() {
@Override
public FromPointerReader<VatId> getVatIdFactory() {
return this.getVatIdFactory();
}
@Override
public Capability.Client createFor(VatId clientId) {
return bootstrapInterface;
}
});
}
public RpcSystem(VatNetwork<VatId> network,
BootstrapFactory<VatId> bootstrapFactory) {
this.network = network;
this.bootstrapFactory = bootstrapFactory;
this.acceptLoop = doAcceptLoop();
this.messageLoop = doMessageLoop();
}
public CompletableFuture<java.lang.Void> getMessageLoop() {
return this.messageLoop;
public Capability.Client bootstrap(VatId vatId) {
var connection = this.getNetwork().connect(vatId);
if (connection != null) {
var state = getConnectionState(connection);
var hook = state.restore();
return new Capability.Client(hook);
}
else if (this.bootstrapFactory != null) {
return this.bootstrapFactory.createFor(vatId);
}
else {
return new Capability.Client(Capability.newBrokenCap("No bootstrap interface available"));
}
}
private CompletableFuture<java.lang.Void> getAcceptLoop() {
return this.acceptLoop;
}
RpcState<VatId> getConnectionState(VatNetwork.Connection<VatId> connection) {
public void accept(VatNetwork.Connection connection) {
getConnectionState(connection);
}
RpcState getConnectionState(VatNetwork.Connection connection) {
var onDisconnect = new CompletableFuture<VatNetwork.Connection>()
var onDisconnect = new CompletableFuture<VatNetwork.Connection<VatId>>()
.thenAccept(lostConnection -> {
this.connections.remove(lostConnection);
});
return connections.computeIfAbsent(connection, key ->
new RpcState(bootstrapInterface, connection, onDisconnect));
new RpcState<VatId>(this.bootstrapFactory, connection, onDisconnect));
}
public void accept(VatNetwork.Connection<VatId> connection) {
getConnectionState(connection);
}
private CompletableFuture<java.lang.Void> doAcceptLoop() {
return this.network.baseAccept().thenCompose(connection -> {
return this.getNetwork().baseAccept().thenCompose(connection -> {
this.accept(connection);
return this.doAcceptLoop();
});
@ -56,4 +91,29 @@ public abstract class RpcSystem<VatId> {
}
return accept.thenCompose(x -> this.doMessageLoop());
}
public CompletableFuture<java.lang.Void> getMessageLoop() {
return this.messageLoop;
}
private CompletableFuture<java.lang.Void> getAcceptLoop() {
return this.acceptLoop;
}
public static <VatId extends StructReader>
RpcSystem<VatId> makeRpcClient(VatNetwork<VatId> network) {
return new RpcSystem<>(network);
}
public static <VatId extends StructReader>
RpcSystem<VatId> makeRpcServer(VatNetwork<VatId> network,
BootstrapFactory<VatId> bootstrapFactory) {
return new RpcSystem<>(network, bootstrapFactory);
}
public static <VatId extends StructReader>
RpcSystem<VatId> makeRpcServer(VatNetwork<VatId> network,
Capability.Client bootstrapInterface) {
return new RpcSystem<>(network, bootstrapInterface);
}
}

View file

@ -3,18 +3,20 @@ package org.capnproto;
public class TwoPartyRpcSystem
extends RpcSystem<RpcTwoPartyProtocol.VatId.Reader> {
private TwoPartyVatNetwork network;
public TwoPartyRpcSystem(TwoPartyVatNetwork network, Capability.Client bootstrapInterface) {
super(network, bootstrapInterface);
this.network = network;
}
public TwoPartyRpcSystem(TwoPartyVatNetwork network, Capability.Server bootstrapInterface) {
super(network, new Capability.Client(bootstrapInterface));
this.network = network;
}
public Capability.Client bootstrap(RpcTwoPartyProtocol.VatId.Reader vatId) {
var connection = this.network.baseConnect(vatId);
var state = getConnectionState(connection);
var hook = state.restore();
return new Capability.Client(hook);
@Override
public VatNetwork<RpcTwoPartyProtocol.VatId.Reader> getNetwork() {
return this.network;
}
}

View file

@ -4,9 +4,15 @@ import java.nio.channels.AsynchronousSocketChannel;
import java.util.List;
import java.util.concurrent.CompletableFuture;
public class TwoPartyVatNetwork
implements VatNetwork<RpcTwoPartyProtocol.VatId.Reader>,
VatNetwork.Connection {
VatNetwork.Connection<RpcTwoPartyProtocol.VatId.Reader> {
@Override
public CompletableFuture<Connection<RpcTwoPartyProtocol.VatId.Reader>> baseAccept() {
return this.accept();
}
public interface MessageTap {
void incoming(IncomingRpcMessage message, RpcTwoPartyProtocol.Side side);
@ -33,25 +39,22 @@ public class TwoPartyVatNetwork
return side;
}
public RpcTwoPartyProtocol.VatId.Reader getPeerVatId() {
return peerVatId.getRoot(RpcTwoPartyProtocol.VatId.factory).asReader();
}
public void setTap(MessageTap tap) {
this.tap = tap;
}
public VatNetwork.Connection asConnection() {
public Connection asConnection() {
return this;
}
private Connection connect(RpcTwoPartyProtocol.VatId.Reader vatId) {
@Override
public Connection connect(RpcTwoPartyProtocol.VatId.Reader vatId) {
return vatId.getSide() != side
? this.asConnection()
: null;
}
private CompletableFuture<Connection> accept() {
public CompletableFuture<Connection<RpcTwoPartyProtocol.VatId.Reader>> accept() {
if (side == RpcTwoPartyProtocol.Side.SERVER & !accepted) {
accepted = true;
return CompletableFuture.completedFuture(this.asConnection());
@ -62,6 +65,10 @@ public class TwoPartyVatNetwork
}
}
public RpcTwoPartyProtocol.VatId.Reader getPeerVatId() {
return this.peerVatId.getRoot(RpcTwoPartyProtocol.VatId.factory).asReader();
}
@Override
public OutgoingRpcMessage newOutgoingMessage(int firstSegmentWordSize) {
return new OutgoingMessage(firstSegmentWordSize);
@ -111,16 +118,6 @@ public class TwoPartyVatNetwork
});
}
@Override
public Connection baseConnect(RpcTwoPartyProtocol.VatId.Reader hostId) {
return this.connect(hostId);
}
@Override
public CompletableFuture<Connection> baseAccept() {
return this.accept();
}
final class OutgoingMessage implements OutgoingRpcMessage {
private final MessageBuilder message;

View file

@ -2,9 +2,9 @@ package org.capnproto;
import java.util.concurrent.CompletableFuture;
public interface VatNetwork<VatId> {
interface Connection {
public interface VatNetwork<VatId>
{
interface Connection<VatId> {
default OutgoingRpcMessage newOutgoingMessage() {
return newOutgoingMessage(0);
}
@ -12,8 +12,13 @@ public interface VatNetwork<VatId> {
CompletableFuture<IncomingRpcMessage> receiveIncomingMessage();
CompletableFuture<java.lang.Void> onDisconnect();
CompletableFuture<java.lang.Void> shutdown();
VatId getPeerVatId();
}
Connection baseConnect(VatId hostId);
CompletableFuture<Connection> baseAccept();
CompletableFuture<Connection<VatId>> baseAccept();
//FromPointerReader<VatId> getVatIdFactory();
Connection<VatId> connect(VatId hostId);
}

View file

@ -21,8 +21,10 @@
package org.capnproto;
import org.capnproto.test.Test;
import org.junit.Assert;
import org.junit.Test;
import org.junit.*;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutionException;
@ -33,37 +35,7 @@ class Counter {
int value() { return count; }
}
class TestInterfaceImpl extends org.capnproto.test.Test.TestInterface.Server {
final Counter counter;
TestInterfaceImpl(Counter counter) {
this.counter = counter;
}
@Override
protected CompletableFuture<java.lang.Void> foo(CallContext<org.capnproto.test.Test.TestInterface.FooParams.Reader, org.capnproto.test.Test.TestInterface.FooResults.Builder> ctx) {
this.counter.inc();
var params = ctx.getParams();
var result = ctx.getResults();
Assert.assertEquals(123, params.getI());
Assert.assertTrue(params.getJ());
result.setX("foo");
return READY_NOW;
}
@Override
protected CompletableFuture<java.lang.Void> baz(CallContext<org.capnproto.test.Test.TestInterface.BazParams.Reader, org.capnproto.test.Test.TestInterface.BazResults.Builder> context) {
this.counter.inc();
var params = context.getParams();
TestUtil.checkTestMessage(params.getS());
context.releaseParams();
Assert.assertThrows(RpcException.class, () -> context.getParams());
return READY_NOW;
}
}
class TestExtendsImpl extends org.capnproto.test.Test.TestExtends2.Server {
class TestExtendsImpl extends Test.TestExtends2.Server {
final Counter counter;
@ -72,7 +44,7 @@ class TestExtendsImpl extends org.capnproto.test.Test.TestExtends2.Server {
}
@Override
protected CompletableFuture<java.lang.Void> foo(CallContext<org.capnproto.test.Test.TestInterface.FooParams.Reader, org.capnproto.test.Test.TestInterface.FooResults.Builder> context) {
protected CompletableFuture<java.lang.Void> foo(CallContext<Test.TestInterface.FooParams.Reader, Test.TestInterface.FooResults.Builder> context) {
counter.inc();
var params = context.getParams();
var result = context.getResults();
@ -83,7 +55,7 @@ class TestExtendsImpl extends org.capnproto.test.Test.TestExtends2.Server {
}
@Override
protected CompletableFuture<java.lang.Void> grault(CallContext<org.capnproto.test.Test.TestExtends.GraultParams.Reader, org.capnproto.test.Test.TestAllTypes.Builder> context) {
protected CompletableFuture<java.lang.Void> grault(CallContext<Test.TestExtends.GraultParams.Reader, Test.TestAllTypes.Builder> context) {
counter.inc();
context.releaseParams();
TestUtil.initTestMessage(context.getResults());
@ -91,50 +63,25 @@ class TestExtendsImpl extends org.capnproto.test.Test.TestExtends2.Server {
}
}
class TestPipelineImpl extends org.capnproto.test.Test.TestPipeline.Server {
class TestCallOrderImpl extends Test.TestCallOrder.Server {
final Counter counter;
TestPipelineImpl(Counter counter) {
this.counter = counter;
}
int count = 0;
@Override
protected CompletableFuture<java.lang.Void> getCap(CallContext<org.capnproto.test.Test.TestPipeline.GetCapParams.Reader, org.capnproto.test.Test.TestPipeline.GetCapResults.Builder> ctx) {
this.counter.inc();
var params = ctx.getParams();
Assert.assertEquals(234, params.getN());
var cap = params.getInCap();
ctx.releaseParams();
var request = cap.fooRequest();
var fooParams = request.getParams();
fooParams.setI(123);
fooParams.setJ(true);
return request.send().thenAccept(response -> {
Assert.assertEquals("foo", response.getX().toString());
var result = ctx.getResults();
result.setS("bar");
org.capnproto.test.Test.TestExtends.Server server = new TestExtendsImpl(this.counter);
result.initOutBox().setCap(server);
});
}
@Override
protected CompletableFuture<java.lang.Void> getAnyCap(CallContext<org.capnproto.test.Test.TestPipeline.GetAnyCapParams.Reader, org.capnproto.test.Test.TestPipeline.GetAnyCapResults.Builder> context) {
return super.getAnyCap(context);
protected CompletableFuture<java.lang.Void> getCallSequence(CallContext<Test.TestCallOrder.GetCallSequenceParams.Reader, Test.TestCallOrder.GetCallSequenceResults.Builder> context) {
var result = context.getResults();
result.setN(this.count++);
return READY_NOW;
}
}
public class CapabilityTest {
@Test
@org.junit.Test
public void testBasic() {
var callCount = new Counter();
var client = new org.capnproto.test.Test.TestInterface.Client(
new TestInterfaceImpl(callCount));
var client = new Test.TestInterface.Client(
new TestUtil.TestInterfaceImpl(callCount));
var request1 = client.fooRequest();
request1.getParams().setI(123);
@ -155,15 +102,15 @@ public class CapabilityTest {
});
}
@Test
@org.junit.Test
public void testInheritance() throws ExecutionException, InterruptedException {
var callCount = new Counter();
var client1 = new org.capnproto.test.Test.TestExtends.Client(
var client1 = new Test.TestExtends.Client(
new TestExtendsImpl(callCount));
org.capnproto.test.Test.TestInterface.Client client2 = client1;
var client = (org.capnproto.test.Test.TestExtends.Client)client2;
Test.TestInterface.Client client2 = client1;
var client = (Test.TestExtends.Client)client2;
var request1 = client.fooRequest();
request1.getParams().setI(321);
@ -183,26 +130,26 @@ public class CapabilityTest {
Assert.assertEquals(2, callCount.value());
}
@Test
@org.junit.Test
public void testPipelining() throws ExecutionException, InterruptedException {
var callCount = new Counter();
var chainedCallCount = new Counter();
var client = new org.capnproto.test.Test.TestPipeline.Client(
new TestPipelineImpl(callCount));
var client = new Test.TestPipeline.Client(
new TestUtil.TestPipelineImpl(callCount));
var request = client.getCapRequest();
var params = request.getParams();
params.setN(234);
params.setInCap(new org.capnproto.test.Test.TestInterface.Client(
new TestInterfaceImpl(chainedCallCount)));
params.setInCap(new Test.TestInterface.Client(
new TestUtil.TestInterfaceImpl(chainedCallCount)));
var promise = request.send();
var outbox = promise.getOutBox();
var pipelineRequest = outbox.getCap().fooRequest();
pipelineRequest.getParams().setI(321);
var pipelinePromise = pipelineRequest.send();
var pipelineRequest2 = new org.capnproto.test.Test.TestExtends.Client(promise.getOutBox().getCap()).graultRequest();
var pipelineRequest2 = new Test.TestExtends.Client(promise.getOutBox().getCap()).graultRequest();
var pipelinePromise2 = pipelineRequest2.send();
// Hmm, we have no means to defer the evaluation of callInternal. The best we can do is
@ -219,7 +166,7 @@ public class CapabilityTest {
Assert.assertEquals(1, chainedCallCount.value());
}
class TestThisCap extends org.capnproto.test.Test.TestInterface.Server {
class TestThisCap extends Test.TestInterface.Server {
Counter counter;
@ -227,29 +174,29 @@ public class CapabilityTest {
this.counter = counter;
}
org.capnproto.test.Test.TestInterface.Client getSelf() {
Test.TestInterface.Client getSelf() {
return this.thisCap();
}
@Override
protected CompletableFuture<java.lang.Void> bar(CallContext<org.capnproto.test.Test.TestInterface.BarParams.Reader, org.capnproto.test.Test.TestInterface.BarResults.Builder> context) {
protected CompletableFuture<java.lang.Void> bar(CallContext<Test.TestInterface.BarParams.Reader, Test.TestInterface.BarResults.Builder> context) {
this.counter.inc();
return READY_NOW;
}
}
@Test
@org.junit.Test
public void testGenerics() {
var factory = org.capnproto.test.Test.TestGenerics.newFactory(org.capnproto.test.Test.TestAllTypes.factory, AnyPointer.factory);
var factory = Test.TestGenerics.newFactory(Test.TestAllTypes.factory, AnyPointer.factory);
}
@Test
@org.junit.Test
public void thisCap() {
var callCount = new Counter();
var server = new TestThisCap(callCount);
var client = new org.capnproto.test.Test.TestInterface.Client(server);
var client = new Test.TestInterface.Client(server);
client.barRequest().send().join();
Assert.assertEquals(1, callCount.value());

View file

@ -8,8 +8,6 @@ import org.junit.Test;
import java.util.ArrayDeque;
import java.util.Queue;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.Executor;
import java.util.concurrent.Executors;
public class RpcStateTest {
@ -23,7 +21,7 @@ public class RpcStateTest {
}
}
class TestConnection implements VatNetwork.Connection {
class TestConnection implements VatNetwork.Connection<RpcTwoPartyProtocol.VatId.Reader> {
private CompletableFuture<IncomingRpcMessage> nextIncomingMessage = new CompletableFuture<>();
private final CompletableFuture<java.lang.Void> disconnect = new CompletableFuture<>();
@ -69,6 +67,11 @@ public class RpcStateTest {
this.disconnect.complete(null);
return this.disconnect.copy();
}
@Override
public RpcTwoPartyProtocol.VatId.Reader getPeerVatId() {
return null;
}
}
TestConnection connection;
@ -80,7 +83,19 @@ public class RpcStateTest {
public void setUp() throws Exception {
this.connection = new TestConnection();
this.bootstrapInterface = new Capability.Client(Capability.newNullCap());
this.rpc = new RpcState(bootstrapInterface, connection, connection.disconnect);
var bootstrapFactory = new BootstrapFactory<RpcTwoPartyProtocol.VatId.Reader>() {
@Override
public FromPointerReader<RpcTwoPartyProtocol.VatId.Reader> getVatIdFactory() {
return RpcTwoPartyProtocol.VatId.factory;
}
@Override
public Capability.Client createFor(RpcTwoPartyProtocol.VatId.Reader clientId) {
return bootstrapInterface;
}
};
this.rpc = new RpcState<RpcTwoPartyProtocol.VatId.Reader>(bootstrapFactory, connection, connection.disconnect);
}
@After

View file

@ -1,9 +1,12 @@
package org.capnproto;
import org.capnproto.test.Test;
import org.junit.Assert;
import java.util.concurrent.CompletableFuture;
class TestUtil {
static void initTestMessage(org.capnproto.test.Test.TestAllTypes.Builder builder) {
static void initTestMessage(Test.TestAllTypes.Builder builder) {
builder.setVoidField(Void.VOID);
builder.setBoolField(true);
builder.setInt8Field((byte) -123);
@ -12,26 +15,165 @@ class TestUtil {
builder.setInt64Field(-123456789012345L);
builder.setUInt8Field((byte) 234);
builder.setUInt16Field((short) 45678);
builder.setUInt32Field((int) 3456789012l);
builder.setUInt32Field((int) 3456789012L);
builder.setUInt64Field(1234567890123456789L);
builder.setFloat32Field(1234.5f);
builder.setFloat64Field(-123e45);
builder.setTextField("foo");
}
static void checkTestMessage(org.capnproto.test.Test.TestAllTypes.Reader reader) {
static void checkTestMessage(Test.TestAllTypes.Reader reader) {
Assert.assertEquals(Void.VOID, reader.getVoidField());
Assert.assertTrue(reader.getBoolField());
Assert.assertEquals((byte)-123, reader.getInt8Field());
Assert.assertEquals((short)-12345, reader.getInt16Field());
Assert.assertEquals(-12345678, reader.getInt32Field());
Assert.assertEquals(-123456789012345l, reader.getInt64Field());
Assert.assertEquals(-123456789012345L, reader.getInt64Field());
Assert.assertEquals((byte)234, reader.getUInt8Field());
Assert.assertEquals((short)45678, reader.getUInt16Field());
Assert.assertEquals((int)3456789012l, reader.getUInt32Field());
Assert.assertEquals(1234567890123456789l, reader.getUInt64Field());
Assert.assertEquals((int) 3456789012L, reader.getUInt32Field());
Assert.assertEquals(1234567890123456789L, reader.getUInt64Field());
Assert.assertEquals(null, 1234.5f, reader.getFloat32Field(), 0.1f);
Assert.assertEquals(null, -123e45, reader.getFloat64Field(), 0.1f);
Assert.assertEquals("foo", reader.getTextField().toString());
}
static class TestInterfaceImpl extends Test.TestInterface.Server {
final Counter counter;
TestInterfaceImpl(Counter counter) {
this.counter = counter;
}
@Override
protected CompletableFuture<java.lang.Void> foo(CallContext<Test.TestInterface.FooParams.Reader, Test.TestInterface.FooResults.Builder> ctx) {
this.counter.inc();
var params = ctx.getParams();
var result = ctx.getResults();
Assert.assertEquals(123, params.getI());
Assert.assertTrue(params.getJ());
result.setX("foo");
return READY_NOW;
}
@Override
protected CompletableFuture<java.lang.Void> baz(CallContext<Test.TestInterface.BazParams.Reader, Test.TestInterface.BazResults.Builder> context) {
this.counter.inc();
var params = context.getParams();
checkTestMessage(params.getS());
context.releaseParams();
return READY_NOW;
}
}
static class TestTailCallerImpl extends Test.TestTailCaller.Server {
private final Counter count;
public TestTailCallerImpl(Counter count) {
this.count = count;
}
@Override
protected CompletableFuture<java.lang.Void> foo(CallContext<Test.TestTailCaller.FooParams.Reader, Test.TestTailCallee.TailResult.Builder> context) {
this.count.inc();
var params = context.getParams();
var tailRequest = params.getCallee().fooRequest();
tailRequest.getParams().setI(params.getI());
tailRequest.getParams().setT("from TestTailCaller");
return context.tailCall(tailRequest);
}
public int getCount() {
return this.count.value();
}
}
static class TestMoreStuffImpl extends Test.TestMoreStuff.Server {
final Counter callCount;
final Counter handleCount;
public TestMoreStuffImpl(Counter callCount, Counter handleCount) {
this.callCount = callCount;
this.handleCount = handleCount;
}
}
static class TestTailCalleeImpl extends Test.TestTailCallee.Server {
private final Counter count;
public TestTailCalleeImpl(Counter count) {
this.count = count;
}
@Override
protected CompletableFuture<java.lang.Void> foo(CallContext<Test.TestTailCallee.FooParams.Reader, Test.TestTailCallee.TailResult.Builder> context) {
this.count.inc();
var params = context.getParams();
var results = context.getResults();
results.setI(params.getI());
results.setT(params.getT());
results.setC(new TestCallOrderImpl());
return READY_NOW;
}
}
static class TestPipelineImpl extends Test.TestPipeline.Server {
final Counter callCount;
TestPipelineImpl(Counter callCount) {
this.callCount = callCount;
}
@Override
protected CompletableFuture<java.lang.Void> getCap(CallContext<Test.TestPipeline.GetCapParams.Reader, Test.TestPipeline.GetCapResults.Builder> ctx) {
this.callCount.inc();
var params = ctx.getParams();
Assert.assertEquals(234, params.getN());
var cap = params.getInCap();
ctx.releaseParams();
var request = cap.fooRequest();
var fooParams = request.getParams();
fooParams.setI(123);
fooParams.setJ(true);
return request.send().thenAccept(response -> {
Assert.assertEquals("foo", response.getX().toString());
var result = ctx.getResults();
result.setS("bar");
Test.TestExtends.Server server = new TestExtendsImpl(this.callCount);
result.initOutBox().setCap(server);
});
}
@Override
protected CompletableFuture<java.lang.Void> getAnyCap(CallContext<Test.TestPipeline.GetAnyCapParams.Reader, Test.TestPipeline.GetAnyCapResults.Builder> context) {
this.callCount.inc();
var params = context.getParams();
Assert.assertEquals(234, params.getN());
var cap = params.getInCap();
context.releaseParams();
var request = new Test.TestInterface.Client(cap).fooRequest();
request.getParams().setI(123);
request.getParams().setJ(true);
return request.send().thenAccept(response -> {
Assert.assertEquals("foo", response.getX().toString());
var result = context.getResults();
result.setS("bar");
result.initOutBox().setCap(new TestExtendsImpl(callCount));
});
}
}
}

View file

@ -21,6 +21,32 @@ struct TestAllTypes {
dataField @13 : Data;
}
struct TestSturdyRef {
hostId @0 :TestSturdyRefHostId;
objectId @1 :AnyPointer;
}
struct TestSturdyRefHostId {
host @0 :Text;
}
struct TestSturdyRefObjectId {
tag @0 :Tag;
enum Tag {
testInterface @0;
testExtends @1;
testPipeline @2;
testTailCallee @3;
testTailCaller @4;
testMoreStuff @5;
}
}
struct TestProvisionId {}
struct TestRecipientId {}
struct TestThirdPartyCapId {}
struct TestJoinResult {}
interface TestInterface {
foo @0 (i :UInt32, j :Bool) -> (x :Text);
bar @1 () -> ();
@ -48,6 +74,76 @@ interface TestPipeline {
}
}
interface TestCallOrder {
getCallSequence @0 (expected: UInt32) -> (n: UInt32);
# First call returns 0, next returns 1, ...
#
# The input `expected` is ignored but useful for disambiguating debug logs.
}
interface TestTailCallee {
struct TailResult {
i @0 :UInt32;
t @1 :Text;
c @2 :TestCallOrder;
}
foo @0 (i :Int32, t :Text) -> TailResult;
}
interface TestTailCaller {
foo @0 (i :Int32, callee :TestTailCallee) -> TestTailCallee.TailResult;
}
interface TestHandle {}
interface TestMoreStuff extends(TestCallOrder) {
# Catch-all type that contains lots of testing methods.
callFoo @0 (cap :TestInterface) -> (s: Text);
# Call `cap.foo()`, check the result, and return "bar".
callFooWhenResolved @1 (cap :TestInterface) -> (s: Text);
# Like callFoo but waits for `cap` to resolve first.
neverReturn @2 (cap :TestInterface) -> (capCopy :TestInterface);
# Doesn't return. You should cancel it.
hold @3 (cap :TestInterface) -> ();
# Returns immediately but holds on to the capability.
callHeld @4 () -> (s: Text);
# Calls the capability previously held using `hold` (and keeps holding it).
getHeld @5 () -> (cap :TestInterface);
# Returns the capability previously held using `hold` (and keeps holding it).
echo @6 (cap :TestCallOrder) -> (cap :TestCallOrder);
# Just returns the input cap.
expectCancel @7 (cap :TestInterface) -> ();
# evalLater()-loops forever, holding `cap`. Must be canceled.
methodWithDefaults @8 (a :Text, b :UInt32 = 123, c :Text = "foo") -> (d :Text, e :Text = "bar");
methodWithNullDefault @12 (a :Text, b :TestInterface = null);
getHandle @9 () -> (handle :TestHandle);
# Get a new handle. Tests have an out-of-band way to check the current number of live handles, so
# this can be used to test garbage collection.
getNull @10 () -> (nullCap :TestMoreStuff);
# Always returns a null capability.
getEnormousString @11 () -> (str :Text);
# Attempts to return an 100MB string. Should always fail.
writeToFd @13 (fdCap1 :TestInterface, fdCap2 :TestInterface)
-> (fdCap3 :TestInterface, secondFdPresent :Bool);
# Expects fdCap1 and fdCap2 wrap socket file descriptors. Writes "foo" to the first and "bar" to
# the second. Also creates a socketpair, writes "baz" to one end, and returns the other end.
}
struct TestGenerics(Foo, Bar) {
foo @0 :Foo;
rev @1 :TestGenerics(Bar, Foo);