add rpc bootstrap factory

This commit is contained in:
Vaci Koblizek 2020-11-02 21:39:58 +00:00
parent 9d023f0449
commit 054213a0ac
11 changed files with 446 additions and 184 deletions

View file

@ -0,0 +1,8 @@
package org.capnproto;
public interface BootstrapFactory<VatId> {
FromPointerReader<VatId> getVatIdFactory();
Capability.Client createFor(VatId clientId);
}

View file

@ -24,7 +24,8 @@ public interface Request<Params> {
var hook = new RequestHook() { var hook = new RequestHook() {
@Override @Override
public RemotePromise<AnyPointer.Reader> send() { public RemotePromise<AnyPointer.Reader> send() {
return new RemotePromise<>(CompletableFuture.failedFuture(exc), null); return new RemotePromise<>(CompletableFuture.failedFuture(exc),
new AnyPointer.Pipeline(PipelineHook.newBrokenPipeline(exc)));
} }
@Override @Override
@ -47,7 +48,7 @@ public interface Request<Params> {
@Override @Override
public Request<AnyPointer.Builder> getTypelessRequest() { public Request<AnyPointer.Builder> getTypelessRequest() {
return null; return new AnyPointer.Request(message.getRoot(AnyPointer.factory), hook);
} }
}; };
} }

View file

@ -1,6 +1,7 @@
package org.capnproto; package org.capnproto;
import java.io.IOException; import java.io.IOException;
import java.lang.ref.Reference;
import java.lang.ref.ReferenceQueue; import java.lang.ref.ReferenceQueue;
import java.lang.ref.WeakReference; import java.lang.ref.WeakReference;
import java.util.*; import java.util.*;
@ -9,7 +10,7 @@ import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionStage; import java.util.concurrent.CompletionStage;
import java.util.function.Consumer; import java.util.function.Consumer;
final class RpcState { final class RpcState<VatId> {
private static int messageSizeHint() { private static int messageSizeHint() {
return 1 + RpcProtocol.Message.factory.structSize().total(); return 1 + RpcProtocol.Message.factory.structSize().total();
@ -19,12 +20,12 @@ final class RpcState {
return RpcProtocol.Exception.factory.structSize().total() + exc.getMessage().length(); return RpcProtocol.Exception.factory.structSize().total() + exc.getMessage().length();
} }
private static int MESSAGE_TARGET_SIZE_HINT private static final int MESSAGE_TARGET_SIZE_HINT
= RpcProtocol.MessageTarget.factory.structSize().total() = RpcProtocol.MessageTarget.factory.structSize().total()
+ RpcProtocol.PromisedAnswer.factory.structSize().total() + RpcProtocol.PromisedAnswer.factory.structSize().total()
+ 16; + 16;
private static int CAP_DESCRIPTOR_SIZE_HINT private static final int CAP_DESCRIPTOR_SIZE_HINT
= RpcProtocol.CapDescriptor.factory.structSize().total() = RpcProtocol.CapDescriptor.factory.structSize().total()
+ RpcProtocol.PromisedAnswer.factory.structSize().total(); + RpcProtocol.PromisedAnswer.factory.structSize().total();
@ -63,7 +64,7 @@ final class RpcState {
} }
} }
private static final class QuestionRef extends WeakReference<Question> { private final class QuestionRef extends WeakReference<Question> {
private final QuestionDisposer disposer; private final QuestionDisposer disposer;
@ -77,7 +78,7 @@ final class RpcState {
} }
} }
private final class Question { private class Question {
CompletableFuture<RpcResponse> response = new CompletableFuture<>(); CompletableFuture<RpcResponse> response = new CompletableFuture<>();
int[] paramExports = new int[0]; int[] paramExports = new int[0];
@ -146,8 +147,8 @@ final class RpcState {
public Iterator<Question> iterator() { public Iterator<Question> iterator() {
return this.slots.values() return this.slots.values()
.stream() .stream()
.map(ref -> ref.get()) .map(Reference::get)
.filter(question -> question != null) .filter(Objects::nonNull)
.iterator(); .iterator();
} }
@ -160,7 +161,7 @@ final class RpcState {
} }
} }
static final class Answer { final class Answer {
final int answerId; final int answerId;
boolean active = false; boolean active = false;
PipelineHook pipeline; PipelineHook pipeline;
@ -228,28 +229,14 @@ final class RpcState {
} }
} }
private final ExportTable<Export> exports = new ExportTable<Export>() { private final ExportTable<Export> exports = new ExportTable<>() {
@Override @Override
Export newExportable(int id) { Export newExportable(int id) {
return new Export(id); return new Export(id);
} }
}; };
/*
private final ExportTable<QuestionRef> questions = new ExportTable<>() {
@Override
QuestionRef newExportable(int id) {
return new QuestionRef(new Question(id));
}
};
*/
private final QuestionExportTable questions = new QuestionExportTable(); private final QuestionExportTable questions = new QuestionExportTable();
/*{
@Override
Question newExportable(int id) {
return new Question(id);
}
*/
private final ImportTable<Answer> answers = new ImportTable<>() { private final ImportTable<Answer> answers = new ImportTable<>() {
@Override @Override
@ -273,8 +260,8 @@ final class RpcState {
}; };
private final Map<ClientHook, Integer> exportsByCap = new HashMap<>(); private final Map<ClientHook, Integer> exportsByCap = new HashMap<>();
private final Capability.Client bootstrapInterface; private final BootstrapFactory<VatId> bootstrapFactory;
private final VatNetwork.Connection connection; private final VatNetwork.Connection<VatId> connection;
private final CompletableFuture<java.lang.Void> onDisconnect; private final CompletableFuture<java.lang.Void> onDisconnect;
private Throwable disconnected = null; private Throwable disconnected = null;
private CompletableFuture<java.lang.Void> messageReady = CompletableFuture.completedFuture(null); private CompletableFuture<java.lang.Void> messageReady = CompletableFuture.completedFuture(null);
@ -282,10 +269,10 @@ final class RpcState {
private final ReferenceQueue<Question> questionRefs = new ReferenceQueue<>(); private final ReferenceQueue<Question> questionRefs = new ReferenceQueue<>();
private final ReferenceQueue<ImportClient> importRefs = new ReferenceQueue<>(); private final ReferenceQueue<ImportClient> importRefs = new ReferenceQueue<>();
RpcState(Capability.Client bootstrapInterface, RpcState(BootstrapFactory<VatId> bootstrapFactory,
VatNetwork.Connection connection, VatNetwork.Connection<VatId> connection,
CompletableFuture<java.lang.Void> onDisconnect) { CompletableFuture<java.lang.Void> onDisconnect) {
this.bootstrapInterface = bootstrapInterface; this.bootstrapFactory = bootstrapFactory;
this.connection = connection; this.connection = connection;
this.onDisconnect = onDisconnect; this.onDisconnect = onDisconnect;
this.messageLoop = this.doMessageLoop(); this.messageLoop = this.doMessageLoop();
@ -445,6 +432,7 @@ final class RpcState {
private void handleMessage(IncomingRpcMessage message) throws RpcException { private void handleMessage(IncomingRpcMessage message) throws RpcException {
var reader = message.getBody().getAs(RpcProtocol.Message.factory); var reader = message.getBody().getAs(RpcProtocol.Message.factory);
//System.out.println(reader.which());
switch (reader.which()) { switch (reader.which()) {
case UNIMPLEMENTED: case UNIMPLEMENTED:
handleUnimplemented(reader.getUnimplemented()); handleUnimplemented(reader.getUnimplemented());
@ -548,7 +536,8 @@ final class RpcState {
var payload = ret.initResults(); var payload = ret.initResults();
var content = payload.getContent().imbue(capTable); var content = payload.getContent().imbue(capTable);
content.setAsCap(bootstrapInterface); var cap = this.bootstrapFactory.createFor(connection.getPeerVatId());
content.setAsCap(cap);
var caps = capTable.getTable(); var caps = capTable.getTable();
var capHook = caps.length != 0 var capHook = caps.length != 0
? caps[0] ? caps[0]
@ -1193,7 +1182,7 @@ final class RpcState {
AnyPointer.Builder getResultsBuilder(); AnyPointer.Builder getResultsBuilder();
} }
static class RpcResponseImpl implements RpcResponse { class RpcResponseImpl implements RpcResponse {
private final Question question; private final Question question;
private final IncomingRpcMessage message; private final IncomingRpcMessage message;
private final AnyPointer.Reader results; private final AnyPointer.Reader results;

View file

@ -4,46 +4,81 @@ import java.util.HashMap;
import java.util.Map; import java.util.Map;
import java.util.concurrent.CompletableFuture; import java.util.concurrent.CompletableFuture;
public abstract class RpcSystem<VatId> { public class RpcSystem<VatId extends StructReader> {
final VatNetwork<VatId> network; private final VatNetwork<VatId> network;
final Capability.Client bootstrapInterface; private final BootstrapFactory<VatId> bootstrapFactory;
final Map<VatNetwork.Connection, RpcState> connections = new HashMap<>(); private final Map<VatNetwork.Connection<VatId>, RpcState<VatId>> connections = new HashMap<>();
final CompletableFuture<java.lang.Void> messageLoop; private final CompletableFuture<java.lang.Void> messageLoop;
final CompletableFuture<java.lang.Void> acceptLoop; private final CompletableFuture<java.lang.Void> acceptLoop;
public RpcSystem(VatNetwork<VatId> network, Capability.Client bootstrapInterface) { public RpcSystem(VatNetwork<VatId> network) {
this.network = network; this.network = network;
this.bootstrapInterface = bootstrapInterface; this.bootstrapFactory = null;
this.acceptLoop = new CompletableFuture<>();
this.messageLoop = doMessageLoop();
}
public VatNetwork<VatId> getNetwork() {
return this.network;
}
public RpcSystem(VatNetwork<VatId> network,
Capability.Client bootstrapInterface) {
this(network, new BootstrapFactory<VatId>() {
@Override
public FromPointerReader<VatId> getVatIdFactory() {
return this.getVatIdFactory();
}
@Override
public Capability.Client createFor(VatId clientId) {
return bootstrapInterface;
}
});
}
public RpcSystem(VatNetwork<VatId> network,
BootstrapFactory<VatId> bootstrapFactory) {
this.network = network;
this.bootstrapFactory = bootstrapFactory;
this.acceptLoop = doAcceptLoop(); this.acceptLoop = doAcceptLoop();
this.messageLoop = doMessageLoop(); this.messageLoop = doMessageLoop();
} }
public CompletableFuture<java.lang.Void> getMessageLoop() { public Capability.Client bootstrap(VatId vatId) {
return this.messageLoop; var connection = this.getNetwork().connect(vatId);
if (connection != null) {
var state = getConnectionState(connection);
var hook = state.restore();
return new Capability.Client(hook);
}
else if (this.bootstrapFactory != null) {
return this.bootstrapFactory.createFor(vatId);
}
else {
return new Capability.Client(Capability.newBrokenCap("No bootstrap interface available"));
}
} }
private CompletableFuture<java.lang.Void> getAcceptLoop() { RpcState<VatId> getConnectionState(VatNetwork.Connection<VatId> connection) {
return this.acceptLoop;
}
public void accept(VatNetwork.Connection connection) { var onDisconnect = new CompletableFuture<VatNetwork.Connection<VatId>>()
getConnectionState(connection);
}
RpcState getConnectionState(VatNetwork.Connection connection) {
var onDisconnect = new CompletableFuture<VatNetwork.Connection>()
.thenAccept(lostConnection -> { .thenAccept(lostConnection -> {
this.connections.remove(lostConnection); this.connections.remove(lostConnection);
}); });
return connections.computeIfAbsent(connection, key -> return connections.computeIfAbsent(connection, key ->
new RpcState(bootstrapInterface, connection, onDisconnect)); new RpcState<VatId>(this.bootstrapFactory, connection, onDisconnect));
}
public void accept(VatNetwork.Connection<VatId> connection) {
getConnectionState(connection);
} }
private CompletableFuture<java.lang.Void> doAcceptLoop() { private CompletableFuture<java.lang.Void> doAcceptLoop() {
return this.network.baseAccept().thenCompose(connection -> { return this.getNetwork().baseAccept().thenCompose(connection -> {
this.accept(connection); this.accept(connection);
return this.doAcceptLoop(); return this.doAcceptLoop();
}); });
@ -56,4 +91,29 @@ public abstract class RpcSystem<VatId> {
} }
return accept.thenCompose(x -> this.doMessageLoop()); return accept.thenCompose(x -> this.doMessageLoop());
} }
public CompletableFuture<java.lang.Void> getMessageLoop() {
return this.messageLoop;
}
private CompletableFuture<java.lang.Void> getAcceptLoop() {
return this.acceptLoop;
}
public static <VatId extends StructReader>
RpcSystem<VatId> makeRpcClient(VatNetwork<VatId> network) {
return new RpcSystem<>(network);
}
public static <VatId extends StructReader>
RpcSystem<VatId> makeRpcServer(VatNetwork<VatId> network,
BootstrapFactory<VatId> bootstrapFactory) {
return new RpcSystem<>(network, bootstrapFactory);
}
public static <VatId extends StructReader>
RpcSystem<VatId> makeRpcServer(VatNetwork<VatId> network,
Capability.Client bootstrapInterface) {
return new RpcSystem<>(network, bootstrapInterface);
}
} }

View file

@ -3,18 +3,20 @@ package org.capnproto;
public class TwoPartyRpcSystem public class TwoPartyRpcSystem
extends RpcSystem<RpcTwoPartyProtocol.VatId.Reader> { extends RpcSystem<RpcTwoPartyProtocol.VatId.Reader> {
private TwoPartyVatNetwork network;
public TwoPartyRpcSystem(TwoPartyVatNetwork network, Capability.Client bootstrapInterface) { public TwoPartyRpcSystem(TwoPartyVatNetwork network, Capability.Client bootstrapInterface) {
super(network, bootstrapInterface); super(network, bootstrapInterface);
this.network = network;
} }
public TwoPartyRpcSystem(TwoPartyVatNetwork network, Capability.Server bootstrapInterface) { public TwoPartyRpcSystem(TwoPartyVatNetwork network, Capability.Server bootstrapInterface) {
super(network, new Capability.Client(bootstrapInterface)); super(network, new Capability.Client(bootstrapInterface));
this.network = network;
} }
public Capability.Client bootstrap(RpcTwoPartyProtocol.VatId.Reader vatId) { @Override
var connection = this.network.baseConnect(vatId); public VatNetwork<RpcTwoPartyProtocol.VatId.Reader> getNetwork() {
var state = getConnectionState(connection); return this.network;
var hook = state.restore();
return new Capability.Client(hook);
} }
} }

View file

@ -4,9 +4,15 @@ import java.nio.channels.AsynchronousSocketChannel;
import java.util.List; import java.util.List;
import java.util.concurrent.CompletableFuture; import java.util.concurrent.CompletableFuture;
public class TwoPartyVatNetwork public class TwoPartyVatNetwork
implements VatNetwork<RpcTwoPartyProtocol.VatId.Reader>, implements VatNetwork<RpcTwoPartyProtocol.VatId.Reader>,
VatNetwork.Connection { VatNetwork.Connection<RpcTwoPartyProtocol.VatId.Reader> {
@Override
public CompletableFuture<Connection<RpcTwoPartyProtocol.VatId.Reader>> baseAccept() {
return this.accept();
}
public interface MessageTap { public interface MessageTap {
void incoming(IncomingRpcMessage message, RpcTwoPartyProtocol.Side side); void incoming(IncomingRpcMessage message, RpcTwoPartyProtocol.Side side);
@ -33,25 +39,22 @@ public class TwoPartyVatNetwork
return side; return side;
} }
public RpcTwoPartyProtocol.VatId.Reader getPeerVatId() {
return peerVatId.getRoot(RpcTwoPartyProtocol.VatId.factory).asReader();
}
public void setTap(MessageTap tap) { public void setTap(MessageTap tap) {
this.tap = tap; this.tap = tap;
} }
public VatNetwork.Connection asConnection() { public Connection asConnection() {
return this; return this;
} }
private Connection connect(RpcTwoPartyProtocol.VatId.Reader vatId) { @Override
public Connection connect(RpcTwoPartyProtocol.VatId.Reader vatId) {
return vatId.getSide() != side return vatId.getSide() != side
? this.asConnection() ? this.asConnection()
: null; : null;
} }
private CompletableFuture<Connection> accept() { public CompletableFuture<Connection<RpcTwoPartyProtocol.VatId.Reader>> accept() {
if (side == RpcTwoPartyProtocol.Side.SERVER & !accepted) { if (side == RpcTwoPartyProtocol.Side.SERVER & !accepted) {
accepted = true; accepted = true;
return CompletableFuture.completedFuture(this.asConnection()); return CompletableFuture.completedFuture(this.asConnection());
@ -62,6 +65,10 @@ public class TwoPartyVatNetwork
} }
} }
public RpcTwoPartyProtocol.VatId.Reader getPeerVatId() {
return this.peerVatId.getRoot(RpcTwoPartyProtocol.VatId.factory).asReader();
}
@Override @Override
public OutgoingRpcMessage newOutgoingMessage(int firstSegmentWordSize) { public OutgoingRpcMessage newOutgoingMessage(int firstSegmentWordSize) {
return new OutgoingMessage(firstSegmentWordSize); return new OutgoingMessage(firstSegmentWordSize);
@ -111,16 +118,6 @@ public class TwoPartyVatNetwork
}); });
} }
@Override
public Connection baseConnect(RpcTwoPartyProtocol.VatId.Reader hostId) {
return this.connect(hostId);
}
@Override
public CompletableFuture<Connection> baseAccept() {
return this.accept();
}
final class OutgoingMessage implements OutgoingRpcMessage { final class OutgoingMessage implements OutgoingRpcMessage {
private final MessageBuilder message; private final MessageBuilder message;

View file

@ -2,18 +2,23 @@ package org.capnproto;
import java.util.concurrent.CompletableFuture; import java.util.concurrent.CompletableFuture;
public interface VatNetwork<VatId> { public interface VatNetwork<VatId>
{
interface Connection<VatId> {
default OutgoingRpcMessage newOutgoingMessage() {
return newOutgoingMessage(0);
}
OutgoingRpcMessage newOutgoingMessage(int firstSegmentWordSize);
CompletableFuture<IncomingRpcMessage> receiveIncomingMessage();
CompletableFuture<java.lang.Void> onDisconnect();
CompletableFuture<java.lang.Void> shutdown();
VatId getPeerVatId();
}
interface Connection { CompletableFuture<Connection<VatId>> baseAccept();
default OutgoingRpcMessage newOutgoingMessage() {
return newOutgoingMessage(0);
}
OutgoingRpcMessage newOutgoingMessage(int firstSegmentWordSize);
CompletableFuture<IncomingRpcMessage> receiveIncomingMessage();
CompletableFuture<java.lang.Void> onDisconnect();
CompletableFuture<java.lang.Void> shutdown();
}
Connection baseConnect(VatId hostId); //FromPointerReader<VatId> getVatIdFactory();
CompletableFuture<Connection> baseAccept();
Connection<VatId> connect(VatId hostId);
} }

View file

@ -21,8 +21,10 @@
package org.capnproto; package org.capnproto;
import org.capnproto.test.Test;
import org.junit.Assert; import org.junit.Assert;
import org.junit.Test; import org.junit.*;
import java.util.concurrent.CompletableFuture; import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutionException; import java.util.concurrent.ExecutionException;
@ -33,37 +35,7 @@ class Counter {
int value() { return count; } int value() { return count; }
} }
class TestInterfaceImpl extends org.capnproto.test.Test.TestInterface.Server { class TestExtendsImpl extends Test.TestExtends2.Server {
final Counter counter;
TestInterfaceImpl(Counter counter) {
this.counter = counter;
}
@Override
protected CompletableFuture<java.lang.Void> foo(CallContext<org.capnproto.test.Test.TestInterface.FooParams.Reader, org.capnproto.test.Test.TestInterface.FooResults.Builder> ctx) {
this.counter.inc();
var params = ctx.getParams();
var result = ctx.getResults();
Assert.assertEquals(123, params.getI());
Assert.assertTrue(params.getJ());
result.setX("foo");
return READY_NOW;
}
@Override
protected CompletableFuture<java.lang.Void> baz(CallContext<org.capnproto.test.Test.TestInterface.BazParams.Reader, org.capnproto.test.Test.TestInterface.BazResults.Builder> context) {
this.counter.inc();
var params = context.getParams();
TestUtil.checkTestMessage(params.getS());
context.releaseParams();
Assert.assertThrows(RpcException.class, () -> context.getParams());
return READY_NOW;
}
}
class TestExtendsImpl extends org.capnproto.test.Test.TestExtends2.Server {
final Counter counter; final Counter counter;
@ -72,7 +44,7 @@ class TestExtendsImpl extends org.capnproto.test.Test.TestExtends2.Server {
} }
@Override @Override
protected CompletableFuture<java.lang.Void> foo(CallContext<org.capnproto.test.Test.TestInterface.FooParams.Reader, org.capnproto.test.Test.TestInterface.FooResults.Builder> context) { protected CompletableFuture<java.lang.Void> foo(CallContext<Test.TestInterface.FooParams.Reader, Test.TestInterface.FooResults.Builder> context) {
counter.inc(); counter.inc();
var params = context.getParams(); var params = context.getParams();
var result = context.getResults(); var result = context.getResults();
@ -83,7 +55,7 @@ class TestExtendsImpl extends org.capnproto.test.Test.TestExtends2.Server {
} }
@Override @Override
protected CompletableFuture<java.lang.Void> grault(CallContext<org.capnproto.test.Test.TestExtends.GraultParams.Reader, org.capnproto.test.Test.TestAllTypes.Builder> context) { protected CompletableFuture<java.lang.Void> grault(CallContext<Test.TestExtends.GraultParams.Reader, Test.TestAllTypes.Builder> context) {
counter.inc(); counter.inc();
context.releaseParams(); context.releaseParams();
TestUtil.initTestMessage(context.getResults()); TestUtil.initTestMessage(context.getResults());
@ -91,50 +63,25 @@ class TestExtendsImpl extends org.capnproto.test.Test.TestExtends2.Server {
} }
} }
class TestPipelineImpl extends org.capnproto.test.Test.TestPipeline.Server { class TestCallOrderImpl extends Test.TestCallOrder.Server {
final Counter counter; int count = 0;
TestPipelineImpl(Counter counter) {
this.counter = counter;
}
@Override @Override
protected CompletableFuture<java.lang.Void> getCap(CallContext<org.capnproto.test.Test.TestPipeline.GetCapParams.Reader, org.capnproto.test.Test.TestPipeline.GetCapResults.Builder> ctx) { protected CompletableFuture<java.lang.Void> getCallSequence(CallContext<Test.TestCallOrder.GetCallSequenceParams.Reader, Test.TestCallOrder.GetCallSequenceResults.Builder> context) {
this.counter.inc(); var result = context.getResults();
var params = ctx.getParams(); result.setN(this.count++);
Assert.assertEquals(234, params.getN()); return READY_NOW;
var cap = params.getInCap();
ctx.releaseParams();
var request = cap.fooRequest();
var fooParams = request.getParams();
fooParams.setI(123);
fooParams.setJ(true);
return request.send().thenAccept(response -> {
Assert.assertEquals("foo", response.getX().toString());
var result = ctx.getResults();
result.setS("bar");
org.capnproto.test.Test.TestExtends.Server server = new TestExtendsImpl(this.counter);
result.initOutBox().setCap(server);
});
}
@Override
protected CompletableFuture<java.lang.Void> getAnyCap(CallContext<org.capnproto.test.Test.TestPipeline.GetAnyCapParams.Reader, org.capnproto.test.Test.TestPipeline.GetAnyCapResults.Builder> context) {
return super.getAnyCap(context);
} }
} }
public class CapabilityTest { public class CapabilityTest {
@Test @org.junit.Test
public void testBasic() { public void testBasic() {
var callCount = new Counter(); var callCount = new Counter();
var client = new org.capnproto.test.Test.TestInterface.Client( var client = new Test.TestInterface.Client(
new TestInterfaceImpl(callCount)); new TestUtil.TestInterfaceImpl(callCount));
var request1 = client.fooRequest(); var request1 = client.fooRequest();
request1.getParams().setI(123); request1.getParams().setI(123);
@ -155,15 +102,15 @@ public class CapabilityTest {
}); });
} }
@Test @org.junit.Test
public void testInheritance() throws ExecutionException, InterruptedException { public void testInheritance() throws ExecutionException, InterruptedException {
var callCount = new Counter(); var callCount = new Counter();
var client1 = new org.capnproto.test.Test.TestExtends.Client( var client1 = new Test.TestExtends.Client(
new TestExtendsImpl(callCount)); new TestExtendsImpl(callCount));
org.capnproto.test.Test.TestInterface.Client client2 = client1; Test.TestInterface.Client client2 = client1;
var client = (org.capnproto.test.Test.TestExtends.Client)client2; var client = (Test.TestExtends.Client)client2;
var request1 = client.fooRequest(); var request1 = client.fooRequest();
request1.getParams().setI(321); request1.getParams().setI(321);
@ -183,26 +130,26 @@ public class CapabilityTest {
Assert.assertEquals(2, callCount.value()); Assert.assertEquals(2, callCount.value());
} }
@Test @org.junit.Test
public void testPipelining() throws ExecutionException, InterruptedException { public void testPipelining() throws ExecutionException, InterruptedException {
var callCount = new Counter(); var callCount = new Counter();
var chainedCallCount = new Counter(); var chainedCallCount = new Counter();
var client = new org.capnproto.test.Test.TestPipeline.Client( var client = new Test.TestPipeline.Client(
new TestPipelineImpl(callCount)); new TestUtil.TestPipelineImpl(callCount));
var request = client.getCapRequest(); var request = client.getCapRequest();
var params = request.getParams(); var params = request.getParams();
params.setN(234); params.setN(234);
params.setInCap(new org.capnproto.test.Test.TestInterface.Client( params.setInCap(new Test.TestInterface.Client(
new TestInterfaceImpl(chainedCallCount))); new TestUtil.TestInterfaceImpl(chainedCallCount)));
var promise = request.send(); var promise = request.send();
var outbox = promise.getOutBox(); var outbox = promise.getOutBox();
var pipelineRequest = outbox.getCap().fooRequest(); var pipelineRequest = outbox.getCap().fooRequest();
pipelineRequest.getParams().setI(321); pipelineRequest.getParams().setI(321);
var pipelinePromise = pipelineRequest.send(); var pipelinePromise = pipelineRequest.send();
var pipelineRequest2 = new org.capnproto.test.Test.TestExtends.Client(promise.getOutBox().getCap()).graultRequest(); var pipelineRequest2 = new Test.TestExtends.Client(promise.getOutBox().getCap()).graultRequest();
var pipelinePromise2 = pipelineRequest2.send(); var pipelinePromise2 = pipelineRequest2.send();
// Hmm, we have no means to defer the evaluation of callInternal. The best we can do is // Hmm, we have no means to defer the evaluation of callInternal. The best we can do is
@ -219,7 +166,7 @@ public class CapabilityTest {
Assert.assertEquals(1, chainedCallCount.value()); Assert.assertEquals(1, chainedCallCount.value());
} }
class TestThisCap extends org.capnproto.test.Test.TestInterface.Server { class TestThisCap extends Test.TestInterface.Server {
Counter counter; Counter counter;
@ -227,29 +174,29 @@ public class CapabilityTest {
this.counter = counter; this.counter = counter;
} }
org.capnproto.test.Test.TestInterface.Client getSelf() { Test.TestInterface.Client getSelf() {
return this.thisCap(); return this.thisCap();
} }
@Override @Override
protected CompletableFuture<java.lang.Void> bar(CallContext<org.capnproto.test.Test.TestInterface.BarParams.Reader, org.capnproto.test.Test.TestInterface.BarResults.Builder> context) { protected CompletableFuture<java.lang.Void> bar(CallContext<Test.TestInterface.BarParams.Reader, Test.TestInterface.BarResults.Builder> context) {
this.counter.inc(); this.counter.inc();
return READY_NOW; return READY_NOW;
} }
} }
@Test @org.junit.Test
public void testGenerics() { public void testGenerics() {
var factory = org.capnproto.test.Test.TestGenerics.newFactory(org.capnproto.test.Test.TestAllTypes.factory, AnyPointer.factory); var factory = Test.TestGenerics.newFactory(Test.TestAllTypes.factory, AnyPointer.factory);
} }
@Test @org.junit.Test
public void thisCap() { public void thisCap() {
var callCount = new Counter(); var callCount = new Counter();
var server = new TestThisCap(callCount); var server = new TestThisCap(callCount);
var client = new org.capnproto.test.Test.TestInterface.Client(server); var client = new Test.TestInterface.Client(server);
client.barRequest().send().join(); client.barRequest().send().join();
Assert.assertEquals(1, callCount.value()); Assert.assertEquals(1, callCount.value());

View file

@ -8,8 +8,6 @@ import org.junit.Test;
import java.util.ArrayDeque; import java.util.ArrayDeque;
import java.util.Queue; import java.util.Queue;
import java.util.concurrent.CompletableFuture; import java.util.concurrent.CompletableFuture;
import java.util.concurrent.Executor;
import java.util.concurrent.Executors;
public class RpcStateTest { public class RpcStateTest {
@ -23,7 +21,7 @@ public class RpcStateTest {
} }
} }
class TestConnection implements VatNetwork.Connection { class TestConnection implements VatNetwork.Connection<RpcTwoPartyProtocol.VatId.Reader> {
private CompletableFuture<IncomingRpcMessage> nextIncomingMessage = new CompletableFuture<>(); private CompletableFuture<IncomingRpcMessage> nextIncomingMessage = new CompletableFuture<>();
private final CompletableFuture<java.lang.Void> disconnect = new CompletableFuture<>(); private final CompletableFuture<java.lang.Void> disconnect = new CompletableFuture<>();
@ -69,6 +67,11 @@ public class RpcStateTest {
this.disconnect.complete(null); this.disconnect.complete(null);
return this.disconnect.copy(); return this.disconnect.copy();
} }
@Override
public RpcTwoPartyProtocol.VatId.Reader getPeerVatId() {
return null;
}
} }
TestConnection connection; TestConnection connection;
@ -80,7 +83,19 @@ public class RpcStateTest {
public void setUp() throws Exception { public void setUp() throws Exception {
this.connection = new TestConnection(); this.connection = new TestConnection();
this.bootstrapInterface = new Capability.Client(Capability.newNullCap()); this.bootstrapInterface = new Capability.Client(Capability.newNullCap());
this.rpc = new RpcState(bootstrapInterface, connection, connection.disconnect); var bootstrapFactory = new BootstrapFactory<RpcTwoPartyProtocol.VatId.Reader>() {
@Override
public FromPointerReader<RpcTwoPartyProtocol.VatId.Reader> getVatIdFactory() {
return RpcTwoPartyProtocol.VatId.factory;
}
@Override
public Capability.Client createFor(RpcTwoPartyProtocol.VatId.Reader clientId) {
return bootstrapInterface;
}
};
this.rpc = new RpcState<RpcTwoPartyProtocol.VatId.Reader>(bootstrapFactory, connection, connection.disconnect);
} }
@After @After

View file

@ -1,9 +1,12 @@
package org.capnproto; package org.capnproto;
import org.capnproto.test.Test;
import org.junit.Assert; import org.junit.Assert;
import java.util.concurrent.CompletableFuture;
class TestUtil { class TestUtil {
static void initTestMessage(org.capnproto.test.Test.TestAllTypes.Builder builder) { static void initTestMessage(Test.TestAllTypes.Builder builder) {
builder.setVoidField(Void.VOID); builder.setVoidField(Void.VOID);
builder.setBoolField(true); builder.setBoolField(true);
builder.setInt8Field((byte) -123); builder.setInt8Field((byte) -123);
@ -12,26 +15,165 @@ class TestUtil {
builder.setInt64Field(-123456789012345L); builder.setInt64Field(-123456789012345L);
builder.setUInt8Field((byte) 234); builder.setUInt8Field((byte) 234);
builder.setUInt16Field((short) 45678); builder.setUInt16Field((short) 45678);
builder.setUInt32Field((int) 3456789012l); builder.setUInt32Field((int) 3456789012L);
builder.setUInt64Field(1234567890123456789L); builder.setUInt64Field(1234567890123456789L);
builder.setFloat32Field(1234.5f); builder.setFloat32Field(1234.5f);
builder.setFloat64Field(-123e45); builder.setFloat64Field(-123e45);
builder.setTextField("foo"); builder.setTextField("foo");
} }
static void checkTestMessage(org.capnproto.test.Test.TestAllTypes.Reader reader) { static void checkTestMessage(Test.TestAllTypes.Reader reader) {
Assert.assertEquals(Void.VOID, reader.getVoidField()); Assert.assertEquals(Void.VOID, reader.getVoidField());
Assert.assertTrue(reader.getBoolField()); Assert.assertTrue(reader.getBoolField());
Assert.assertEquals((byte)-123, reader.getInt8Field()); Assert.assertEquals((byte)-123, reader.getInt8Field());
Assert.assertEquals((short)-12345, reader.getInt16Field()); Assert.assertEquals((short)-12345, reader.getInt16Field());
Assert.assertEquals(-12345678, reader.getInt32Field()); Assert.assertEquals(-12345678, reader.getInt32Field());
Assert.assertEquals(-123456789012345l, reader.getInt64Field()); Assert.assertEquals(-123456789012345L, reader.getInt64Field());
Assert.assertEquals((byte)234, reader.getUInt8Field()); Assert.assertEquals((byte)234, reader.getUInt8Field());
Assert.assertEquals((short)45678, reader.getUInt16Field()); Assert.assertEquals((short)45678, reader.getUInt16Field());
Assert.assertEquals((int)3456789012l, reader.getUInt32Field()); Assert.assertEquals((int) 3456789012L, reader.getUInt32Field());
Assert.assertEquals(1234567890123456789l, reader.getUInt64Field()); Assert.assertEquals(1234567890123456789L, reader.getUInt64Field());
Assert.assertEquals(null, 1234.5f, reader.getFloat32Field(), 0.1f); Assert.assertEquals(null, 1234.5f, reader.getFloat32Field(), 0.1f);
Assert.assertEquals(null, -123e45, reader.getFloat64Field(), 0.1f); Assert.assertEquals(null, -123e45, reader.getFloat64Field(), 0.1f);
Assert.assertEquals("foo", reader.getTextField().toString()); Assert.assertEquals("foo", reader.getTextField().toString());
} }
static class TestInterfaceImpl extends Test.TestInterface.Server {
final Counter counter;
TestInterfaceImpl(Counter counter) {
this.counter = counter;
}
@Override
protected CompletableFuture<java.lang.Void> foo(CallContext<Test.TestInterface.FooParams.Reader, Test.TestInterface.FooResults.Builder> ctx) {
this.counter.inc();
var params = ctx.getParams();
var result = ctx.getResults();
Assert.assertEquals(123, params.getI());
Assert.assertTrue(params.getJ());
result.setX("foo");
return READY_NOW;
}
@Override
protected CompletableFuture<java.lang.Void> baz(CallContext<Test.TestInterface.BazParams.Reader, Test.TestInterface.BazResults.Builder> context) {
this.counter.inc();
var params = context.getParams();
checkTestMessage(params.getS());
context.releaseParams();
return READY_NOW;
}
}
static class TestTailCallerImpl extends Test.TestTailCaller.Server {
private final Counter count;
public TestTailCallerImpl(Counter count) {
this.count = count;
}
@Override
protected CompletableFuture<java.lang.Void> foo(CallContext<Test.TestTailCaller.FooParams.Reader, Test.TestTailCallee.TailResult.Builder> context) {
this.count.inc();
var params = context.getParams();
var tailRequest = params.getCallee().fooRequest();
tailRequest.getParams().setI(params.getI());
tailRequest.getParams().setT("from TestTailCaller");
return context.tailCall(tailRequest);
}
public int getCount() {
return this.count.value();
}
}
static class TestMoreStuffImpl extends Test.TestMoreStuff.Server {
final Counter callCount;
final Counter handleCount;
public TestMoreStuffImpl(Counter callCount, Counter handleCount) {
this.callCount = callCount;
this.handleCount = handleCount;
}
}
static class TestTailCalleeImpl extends Test.TestTailCallee.Server {
private final Counter count;
public TestTailCalleeImpl(Counter count) {
this.count = count;
}
@Override
protected CompletableFuture<java.lang.Void> foo(CallContext<Test.TestTailCallee.FooParams.Reader, Test.TestTailCallee.TailResult.Builder> context) {
this.count.inc();
var params = context.getParams();
var results = context.getResults();
results.setI(params.getI());
results.setT(params.getT());
results.setC(new TestCallOrderImpl());
return READY_NOW;
}
}
static class TestPipelineImpl extends Test.TestPipeline.Server {
final Counter callCount;
TestPipelineImpl(Counter callCount) {
this.callCount = callCount;
}
@Override
protected CompletableFuture<java.lang.Void> getCap(CallContext<Test.TestPipeline.GetCapParams.Reader, Test.TestPipeline.GetCapResults.Builder> ctx) {
this.callCount.inc();
var params = ctx.getParams();
Assert.assertEquals(234, params.getN());
var cap = params.getInCap();
ctx.releaseParams();
var request = cap.fooRequest();
var fooParams = request.getParams();
fooParams.setI(123);
fooParams.setJ(true);
return request.send().thenAccept(response -> {
Assert.assertEquals("foo", response.getX().toString());
var result = ctx.getResults();
result.setS("bar");
Test.TestExtends.Server server = new TestExtendsImpl(this.callCount);
result.initOutBox().setCap(server);
});
}
@Override
protected CompletableFuture<java.lang.Void> getAnyCap(CallContext<Test.TestPipeline.GetAnyCapParams.Reader, Test.TestPipeline.GetAnyCapResults.Builder> context) {
this.callCount.inc();
var params = context.getParams();
Assert.assertEquals(234, params.getN());
var cap = params.getInCap();
context.releaseParams();
var request = new Test.TestInterface.Client(cap).fooRequest();
request.getParams().setI(123);
request.getParams().setJ(true);
return request.send().thenAccept(response -> {
Assert.assertEquals("foo", response.getX().toString());
var result = context.getResults();
result.setS("bar");
result.initOutBox().setCap(new TestExtendsImpl(callCount));
});
}
}
} }

View file

@ -21,6 +21,32 @@ struct TestAllTypes {
dataField @13 : Data; dataField @13 : Data;
} }
struct TestSturdyRef {
hostId @0 :TestSturdyRefHostId;
objectId @1 :AnyPointer;
}
struct TestSturdyRefHostId {
host @0 :Text;
}
struct TestSturdyRefObjectId {
tag @0 :Tag;
enum Tag {
testInterface @0;
testExtends @1;
testPipeline @2;
testTailCallee @3;
testTailCaller @4;
testMoreStuff @5;
}
}
struct TestProvisionId {}
struct TestRecipientId {}
struct TestThirdPartyCapId {}
struct TestJoinResult {}
interface TestInterface { interface TestInterface {
foo @0 (i :UInt32, j :Bool) -> (x :Text); foo @0 (i :UInt32, j :Bool) -> (x :Text);
bar @1 () -> (); bar @1 () -> ();
@ -48,6 +74,76 @@ interface TestPipeline {
} }
} }
interface TestCallOrder {
getCallSequence @0 (expected: UInt32) -> (n: UInt32);
# First call returns 0, next returns 1, ...
#
# The input `expected` is ignored but useful for disambiguating debug logs.
}
interface TestTailCallee {
struct TailResult {
i @0 :UInt32;
t @1 :Text;
c @2 :TestCallOrder;
}
foo @0 (i :Int32, t :Text) -> TailResult;
}
interface TestTailCaller {
foo @0 (i :Int32, callee :TestTailCallee) -> TestTailCallee.TailResult;
}
interface TestHandle {}
interface TestMoreStuff extends(TestCallOrder) {
# Catch-all type that contains lots of testing methods.
callFoo @0 (cap :TestInterface) -> (s: Text);
# Call `cap.foo()`, check the result, and return "bar".
callFooWhenResolved @1 (cap :TestInterface) -> (s: Text);
# Like callFoo but waits for `cap` to resolve first.
neverReturn @2 (cap :TestInterface) -> (capCopy :TestInterface);
# Doesn't return. You should cancel it.
hold @3 (cap :TestInterface) -> ();
# Returns immediately but holds on to the capability.
callHeld @4 () -> (s: Text);
# Calls the capability previously held using `hold` (and keeps holding it).
getHeld @5 () -> (cap :TestInterface);
# Returns the capability previously held using `hold` (and keeps holding it).
echo @6 (cap :TestCallOrder) -> (cap :TestCallOrder);
# Just returns the input cap.
expectCancel @7 (cap :TestInterface) -> ();
# evalLater()-loops forever, holding `cap`. Must be canceled.
methodWithDefaults @8 (a :Text, b :UInt32 = 123, c :Text = "foo") -> (d :Text, e :Text = "bar");
methodWithNullDefault @12 (a :Text, b :TestInterface = null);
getHandle @9 () -> (handle :TestHandle);
# Get a new handle. Tests have an out-of-band way to check the current number of live handles, so
# this can be used to test garbage collection.
getNull @10 () -> (nullCap :TestMoreStuff);
# Always returns a null capability.
getEnormousString @11 () -> (str :Text);
# Attempts to return an 100MB string. Should always fail.
writeToFd @13 (fdCap1 :TestInterface, fdCap2 :TestInterface)
-> (fdCap3 :TestInterface, secondFdPresent :Bool);
# Expects fdCap1 and fdCap2 wrap socket file descriptors. Writes "foo" to the first and "bar" to
# the second. Also creates a socketpair, writes "baz" to one end, and returns the other end.
}
struct TestGenerics(Foo, Bar) { struct TestGenerics(Foo, Bar) {
foo @0 :Foo; foo @0 :Foo;
rev @1 :TestGenerics(Bar, Foo); rev @1 :TestGenerics(Bar, Foo);