make requests autoclosable and cleanup disconnection

This commit is contained in:
Vaci Koblizek 2020-11-12 22:13:48 +00:00
parent 4e9e7f4068
commit 69a045deec
10 changed files with 276 additions and 195 deletions

View file

@ -3,7 +3,6 @@ package org.capnproto;
import java.io.IOException; import java.io.IOException;
import java.io.PrintWriter; import java.io.PrintWriter;
import java.io.StringWriter; import java.io.StringWriter;
import java.lang.ref.Reference;
import java.lang.ref.ReferenceQueue; import java.lang.ref.ReferenceQueue;
import java.lang.ref.WeakReference; import java.lang.ref.WeakReference;
import java.util.*; import java.util.*;
@ -35,6 +34,16 @@ final class RpcState<VatId> {
= RpcProtocol.CapDescriptor.factory.structSize().total() = RpcProtocol.CapDescriptor.factory.structSize().total()
+ RpcProtocol.PromisedAnswer.factory.structSize().total(); + RpcProtocol.PromisedAnswer.factory.structSize().total();
static class DisconnectInfo {
final CompletableFuture<java.lang.Void> shutdownPromise;
// Task which is working on sending an abort message and cleanly ending the connection.
DisconnectInfo(CompletableFuture<java.lang.Void> shutdownPromise) {
this.shutdownPromise = shutdownPromise;
}
}
private final class QuestionDisposer { private final class QuestionDisposer {
final int id; final int id;
@ -224,7 +233,6 @@ final class RpcState<VatId> {
final static class Embargo { final static class Embargo {
final int id; final int id;
final CompletableFuture<java.lang.Void> disembargo = new CompletableFuture<>(); final CompletableFuture<java.lang.Void> disembargo = new CompletableFuture<>();
Embargo(int id) { Embargo(int id) {
this.id = id; this.id = id;
} }
@ -263,7 +271,7 @@ final class RpcState<VatId> {
private final Map<ClientHook, Integer> exportsByCap = new HashMap<>(); private final Map<ClientHook, Integer> exportsByCap = new HashMap<>();
private final BootstrapFactory<VatId> bootstrapFactory; private final BootstrapFactory<VatId> bootstrapFactory;
private final VatNetwork.Connection<VatId> connection; private final VatNetwork.Connection<VatId> connection;
private final CompletableFuture<java.lang.Void> onDisconnect; private final CompletableFuture<DisconnectInfo> disconnectFulfiller;
private Throwable disconnected = null; private Throwable disconnected = null;
private CompletableFuture<java.lang.Void> messageReady = CompletableFuture.completedFuture(null); private CompletableFuture<java.lang.Void> messageReady = CompletableFuture.completedFuture(null);
private final CompletableFuture<java.lang.Void> messageLoop = new CompletableFuture<>(); private final CompletableFuture<java.lang.Void> messageLoop = new CompletableFuture<>();
@ -273,10 +281,10 @@ final class RpcState<VatId> {
RpcState(BootstrapFactory<VatId> bootstrapFactory, RpcState(BootstrapFactory<VatId> bootstrapFactory,
VatNetwork.Connection<VatId> connection, VatNetwork.Connection<VatId> connection,
CompletableFuture<java.lang.Void> onDisconnect) { CompletableFuture<DisconnectInfo> disconnectFulfiller) {
this.bootstrapFactory = bootstrapFactory; this.bootstrapFactory = bootstrapFactory;
this.connection = connection; this.connection = connection;
this.onDisconnect = onDisconnect; this.disconnectFulfiller = disconnectFulfiller;
startMessageLoop(); startMessageLoop();
} }
@ -284,13 +292,10 @@ final class RpcState<VatId> {
return this.messageLoop; return this.messageLoop;
} }
public CompletableFuture<java.lang.Void> onDisconnect() { void disconnect(Throwable exc) {
return this.messageLoop;
}
CompletableFuture<java.lang.Void> disconnect(Throwable exc) {
if (isDisconnected()) { if (isDisconnected()) {
return CompletableFuture.failedFuture(this.disconnected); // Already disconnected.
return;
} }
var networkExc = RpcException.disconnected(exc.getMessage()); var networkExc = RpcException.disconnected(exc.getMessage());
@ -334,6 +339,7 @@ final class RpcState<VatId> {
} }
} }
// Send an abort message, but ignore failure.
try { try {
int sizeHint = messageSizeHint() + exceptionSizeHint(exc); int sizeHint = messageSizeHint() + exceptionSizeHint(exc);
var message = this.connection.newOutgoingMessage(sizeHint); var message = this.connection.newOutgoingMessage(sizeHint);
@ -344,25 +350,31 @@ final class RpcState<VatId> {
catch (Exception ignored) { catch (Exception ignored) {
} }
var onShutdown = this.connection.shutdown().handle((x, ioExc) -> { var shutdownPromise = this.connection.shutdown()
if (ioExc == null) { .exceptionallyCompose(ioExc -> {
return CompletableFuture.completedFuture(null);
}
// TODO IOException?
assert !(ioExc instanceof IOException); assert !(ioExc instanceof IOException);
if (ioExc instanceof RpcException) { if (ioExc instanceof RpcException) {
var rpcExc = (RpcException)exc; var rpcExc = (RpcException)exc;
// Don't report disconnects as an error
if (rpcExc.getType() == RpcException.Type.DISCONNECTED) { if (rpcExc.getType() == RpcException.Type.DISCONNECTED) {
return CompletableFuture.completedFuture(null); return CompletableFuture.completedFuture(null);
} }
} }
return CompletableFuture.failedFuture(ioExc); return CompletableFuture.failedFuture(ioExc);
}); });
this.disconnected = networkExc; this.disconnected = networkExc;
return onShutdown.thenCompose(x -> CompletableFuture.failedFuture(networkExc)); this.disconnectFulfiller.complete(new DisconnectInfo(shutdownPromise));
for (var pipeline: pipelinesToRelease) {
if (pipeline instanceof RpcState<?>.RpcPipeline) {
((RpcPipeline) pipeline).redirectLater.completeExceptionally(networkExc);
}
}
} }
final boolean isDisconnected() { final boolean isDisconnected() {
@ -389,12 +401,7 @@ final class RpcState<VatId> {
ClientHook restore() { ClientHook restore() {
var question = questions.next(); var question = questions.next();
question.setAwaitingReturn(true); question.setAwaitingReturn(true);
// Run the message loop until the boostrap promise is resolved.
var promise = new CompletableFuture<RpcResponse>(); var promise = new CompletableFuture<RpcResponse>();
var loop = CompletableFuture.anyOf(
getMessageLoop(), promise).thenCompose(x -> promise);
int sizeHint = messageSizeHint(RpcProtocol.Bootstrap.factory); int sizeHint = messageSizeHint(RpcProtocol.Bootstrap.factory);
var message = connection.newOutgoingMessage(sizeHint); var message = connection.newOutgoingMessage(sizeHint);
var builder = message.getBody().initAs(RpcProtocol.Message.factory).initBootstrap(); var builder = message.getBody().initAs(RpcProtocol.Message.factory).initBootstrap();
@ -413,6 +420,7 @@ final class RpcState<VatId> {
var messageReader = this.connection.receiveIncomingMessage() var messageReader = this.connection.receiveIncomingMessage()
.thenAccept(message -> { .thenAccept(message -> {
if (message == null) { if (message == null) {
this.disconnect(RpcException.disconnected("Peer disconnected"));
this.messageLoop.complete(null); this.messageLoop.complete(null);
return; return;
} }
@ -423,11 +431,12 @@ final class RpcState<VatId> {
// or internal RpcState is bad. // or internal RpcState is bad.
this.disconnect(rpcExc); this.disconnect(rpcExc);
} }
this.cleanupImports();
this.cleanupQuestions();
}); });
messageReader.thenRunAsync(this::startMessageLoop); messageReader.thenRunAsync(this::startMessageLoop).exceptionallyCompose(exc -> {
assert exc == null: "Exception in startMessageLoop!";
return CompletableFuture.failedFuture(exc);
});
} }
private void handleMessage(IncomingRpcMessage message) throws RpcException { private void handleMessage(IncomingRpcMessage message) throws RpcException {
@ -470,6 +479,9 @@ final class RpcState<VatId> {
} }
break; break;
} }
this.cleanupImports();
this.cleanupQuestions();
} }
void handleUnimplemented(RpcProtocol.Message.Reader message) { void handleUnimplemented(RpcProtocol.Message.Reader message) {
@ -1427,7 +1439,6 @@ final class RpcState<VatId> {
this.responseSent = false; this.responseSent = false;
sendErrorReturn(exc); sendErrorReturn(exc);
} }
cleanupAnswerTable(exports); cleanupAnswerTable(exports);
} }
@ -1512,6 +1523,7 @@ final class RpcState<VatId> {
RpcPipeline(Question question, RpcPipeline(Question question,
CompletableFuture<RpcResponse> redirectLater) { CompletableFuture<RpcResponse> redirectLater) {
this.question = question; this.question = question;
assert redirectLater != null;
this.redirectLater = redirectLater; this.redirectLater = redirectLater;
} }
@ -1542,6 +1554,11 @@ final class RpcState<VatId> {
return new PromiseClient(pipelineClient, resolutionPromise, null); return new PromiseClient(pipelineClient, resolutionPromise, null);
}); });
} }
@Override
public void close() {
this.question.finish();
}
} }
abstract class RpcClient implements ClientHook { abstract class RpcClient implements ClientHook {
@ -1787,11 +1804,11 @@ final class RpcState<VatId> {
this.cap = initial; this.cap = initial;
this.importId = importId; this.importId = importId;
eventual.whenComplete((resolution, exc) -> { eventual.whenComplete((resolution, exc) -> {
if (exc != null) { if (exc == null) {
resolve(Capability.newBrokenCap(exc)); resolve(resolution);
} }
else { else {
resolve(resolution); resolve(Capability.newBrokenCap(exc));
} }
}); });
} }
@ -1842,6 +1859,10 @@ final class RpcState<VatId> {
// TODO Flow control // TODO Flow control
if (resolutionType == ResolutionType.REFLECTED && receivedCall && !isDisconnected()) { if (resolutionType == ResolutionType.REFLECTED && receivedCall && !isDisconnected()) {
// The new capability is hosted locally, not on the remote machine. And, we had made calls
// to the promise. We need to make sure those calls echo back to us before we allow new
// calls to go directly to the local capability, so we need to set a local embargo and send
// a `Disembargo` to echo through the peer.
int sizeHint = messageSizeHint(RpcProtocol.Disembargo.factory); int sizeHint = messageSizeHint(RpcProtocol.Disembargo.factory);
var message = connection.newOutgoingMessage(sizeHint); var message = connection.newOutgoingMessage(sizeHint);
var disembargo = message.getBody().initAs(RpcProtocol.Message.factory).initDisembargo(); var disembargo = message.getBody().initAs(RpcProtocol.Message.factory).initDisembargo();
@ -1852,7 +1873,8 @@ final class RpcState<VatId> {
disembargo.getContext().setSenderLoopback(embargo.id); disembargo.getContext().setSenderLoopback(embargo.id);
final ClientHook finalReplacement = replacement; final ClientHook finalReplacement = replacement;
var embargoPromise = embargo.disembargo.thenApply(x -> finalReplacement); var embargoPromise = embargo.disembargo.thenApply(
void_ -> finalReplacement);
replacement = Capability.newLocalPromiseClient(embargoPromise); replacement = Capability.newLocalPromiseClient(embargoPromise);
message.send(); message.send();
} }

View file

@ -1,5 +1,6 @@
package org.capnproto; package org.capnproto;
import java.io.IOException;
import java.util.HashMap; import java.util.HashMap;
import java.util.Map; import java.util.Map;
import java.util.concurrent.CompletableFuture; import java.util.concurrent.CompletableFuture;
@ -68,14 +69,21 @@ public class RpcSystem<VatId extends StructReader> {
} }
RpcState<VatId> getConnectionState(VatNetwork.Connection<VatId> connection) { RpcState<VatId> getConnectionState(VatNetwork.Connection<VatId> connection) {
var state = this.connections.get(connection);
var onDisconnect = new CompletableFuture<VatNetwork.Connection<VatId>>() if (state == null) {
.thenAccept(lostConnection -> { var onDisconnect = new CompletableFuture<RpcState.DisconnectInfo>()
this.connections.remove(lostConnection); .whenComplete((info, exc) -> {
this.connections.remove(connection);
try {
connection.close();
} catch (IOException ignored) {
}
}); });
return connections.computeIfAbsent(connection, key -> state = new RpcState<>(this.bootstrapFactory, connection, onDisconnect);
new RpcState<VatId>(this.bootstrapFactory, connection, onDisconnect)); this.connections.put(connection, state);
}
return state;
} }
public void accept(VatNetwork.Connection<VatId> connection) { public void accept(VatNetwork.Connection<VatId> connection) {

View file

@ -20,7 +20,7 @@ public class TwoPartyClient {
Capability.Client bootstrapInterface, Capability.Client bootstrapInterface,
RpcTwoPartyProtocol.Side side) { RpcTwoPartyProtocol.Side side) {
this.network = new TwoPartyVatNetwork(channel, side); this.network = new TwoPartyVatNetwork(channel, side);
this.rpcSystem = new RpcSystem<RpcTwoPartyProtocol.VatId.Reader>(network, bootstrapInterface); this.rpcSystem = new RpcSystem<>(network, bootstrapInterface);
} }
public Capability.Client bootstrap() { public Capability.Client bootstrap() {
@ -31,12 +31,4 @@ public class TwoPartyClient {
: RpcTwoPartyProtocol.Side.CLIENT); : RpcTwoPartyProtocol.Side.CLIENT);
return rpcSystem.bootstrap(vatId.asReader()); return rpcSystem.bootstrap(vatId.asReader());
} }
public TwoPartyVatNetwork getNetwork() {
return this.network;
}
public CompletableFuture<java.lang.Void> onDisconnect() {
return this.network.onDisconnect();
}
} }

View file

@ -1,5 +1,6 @@
package org.capnproto; package org.capnproto;
import java.io.IOException;
import java.nio.channels.AsynchronousSocketChannel; import java.nio.channels.AsynchronousSocketChannel;
import java.util.List; import java.util.List;
import java.util.concurrent.CompletableFuture; import java.util.concurrent.CompletableFuture;
@ -18,7 +19,7 @@ public class TwoPartyVatNetwork
} }
private CompletableFuture<java.lang.Void> previousWrite = CompletableFuture.completedFuture(null); private CompletableFuture<java.lang.Void> previousWrite = CompletableFuture.completedFuture(null);
private final CompletableFuture<java.lang.Void> peerDisconnected = new CompletableFuture<>(); private final CompletableFuture<java.lang.Void> disconnectPromise = new CompletableFuture<>();
private final AsynchronousSocketChannel channel; private final AsynchronousSocketChannel channel;
private final RpcTwoPartyProtocol.Side side; private final RpcTwoPartyProtocol.Side side;
private final MessageBuilder peerVatId = new MessageBuilder(4); private final MessageBuilder peerVatId = new MessageBuilder(4);
@ -34,6 +35,12 @@ public class TwoPartyVatNetwork
: RpcTwoPartyProtocol.Side.CLIENT); : RpcTwoPartyProtocol.Side.CLIENT);
} }
@Override
public void close() throws IOException {
this.channel.close();
this.disconnectPromise.complete(null);
}
public RpcTwoPartyProtocol.Side getSide() { public RpcTwoPartyProtocol.Side getSide() {
return side; return side;
} }
@ -46,6 +53,10 @@ public class TwoPartyVatNetwork
return this; return this;
} }
public CompletableFuture<java.lang.Void> onDisconnect() {
return this.disconnectPromise.copy();
}
@Override @Override
public Connection<RpcTwoPartyProtocol.VatId.Reader> connect(RpcTwoPartyProtocol.VatId.Reader vatId) { public Connection<RpcTwoPartyProtocol.VatId.Reader> connect(RpcTwoPartyProtocol.VatId.Reader vatId) {
return vatId.getSide() != side return vatId.getSide() != side
@ -59,7 +70,7 @@ public class TwoPartyVatNetwork
return CompletableFuture.completedFuture(this.asConnection()); return CompletableFuture.completedFuture(this.asConnection());
} }
else { else {
// never /home/vaci/g/capnproto-java/compilercompletes // never completes
return new CompletableFuture<>(); return new CompletableFuture<>();
} }
} }
@ -97,20 +108,20 @@ public class TwoPartyVatNetwork
return message; return message;
} }
@Override
public CompletableFuture<java.lang.Void> onDisconnect() {
return this.peerDisconnected.copy();
}
@Override @Override
public CompletableFuture<java.lang.Void> shutdown() { public CompletableFuture<java.lang.Void> shutdown() {
return this.previousWrite.whenComplete((x, exc) -> { assert this.previousWrite != null: "Already shut down";
var result = this.previousWrite.thenRun(() -> {
try { try {
this.channel.shutdownOutput(); this.channel.shutdownOutput();
} }
catch (Exception ioExc) { catch (Exception ioExc) {
} }
}); });
this.previousWrite = null;
return result;
} }
final class OutgoingMessage implements OutgoingRpcMessage { final class OutgoingMessage implements OutgoingRpcMessage {

View file

@ -1,24 +1,22 @@
package org.capnproto; package org.capnproto;
import java.io.IOException;
import java.util.concurrent.CompletableFuture; import java.util.concurrent.CompletableFuture;
public interface VatNetwork<VatId> public interface VatNetwork<VatId>
{ {
interface Connection<VatId> { interface Connection<VatId> extends AutoCloseable {
default OutgoingRpcMessage newOutgoingMessage() { default OutgoingRpcMessage newOutgoingMessage() {
return newOutgoingMessage(0); return newOutgoingMessage(0);
} }
OutgoingRpcMessage newOutgoingMessage(int firstSegmentWordSize); OutgoingRpcMessage newOutgoingMessage(int firstSegmentWordSize);
CompletableFuture<IncomingRpcMessage> receiveIncomingMessage(); CompletableFuture<IncomingRpcMessage> receiveIncomingMessage();
CompletableFuture<java.lang.Void> onDisconnect();
CompletableFuture<java.lang.Void> shutdown(); CompletableFuture<java.lang.Void> shutdown();
VatId getPeerVatId(); VatId getPeerVatId();
void close() throws IOException;
} }
CompletableFuture<Connection<VatId>> baseAccept(); CompletableFuture<Connection<VatId>> baseAccept();
//FromPointerReader<VatId> getVatIdFactory();
Connection<VatId> connect(VatId hostId); Connection<VatId> connect(VatId hostId);
} }

View file

@ -5,30 +5,17 @@ import org.junit.Assert;
import org.junit.Before; import org.junit.Before;
import org.junit.Test; import org.junit.Test;
import java.io.IOException;
import java.util.ArrayDeque; import java.util.ArrayDeque;
import java.util.Queue; import java.util.Queue;
import java.util.concurrent.CompletableFuture; import java.util.concurrent.CompletableFuture;
public class RpcStateTest { public class RpcStateTest {
class TestMessage implements IncomingRpcMessage {
MessageBuilder builder = new MessageBuilder();
@Override
public AnyPointer.Reader getBody() {
return builder.getRoot(AnyPointer.factory).asReader();
}
}
class TestConnection implements VatNetwork.Connection<RpcTwoPartyProtocol.VatId.Reader> { class TestConnection implements VatNetwork.Connection<RpcTwoPartyProtocol.VatId.Reader> {
private CompletableFuture<IncomingRpcMessage> nextIncomingMessage = new CompletableFuture<>(); private CompletableFuture<IncomingRpcMessage> nextIncomingMessage = new CompletableFuture<>();
private final CompletableFuture<java.lang.Void> disconnect = new CompletableFuture<>(); private final CompletableFuture<RpcState.DisconnectInfo> disconnect = new CompletableFuture<>();
public void setNextIncomingMessage(IncomingRpcMessage message) {
this.nextIncomingMessage.complete(message);
}
@Override @Override
public OutgoingRpcMessage newOutgoingMessage(int firstSegmentWordSize) { public OutgoingRpcMessage newOutgoingMessage(int firstSegmentWordSize) {
@ -43,6 +30,19 @@ public class RpcStateTest {
@Override @Override
public void send() { public void send() {
sent.add(this); sent.add(this);
var msg = new IncomingRpcMessage() {
@Override
public AnyPointer.Reader getBody() {
return message.getRoot(AnyPointer.factory).asReader();
}
};
if (nextIncomingMessage.isDone()) {
nextIncomingMessage = CompletableFuture.completedFuture(msg);
}
else {
nextIncomingMessage.complete(msg);
}
} }
@Override @Override
@ -57,21 +57,20 @@ public class RpcStateTest {
return this.nextIncomingMessage; return this.nextIncomingMessage;
} }
@Override
public CompletableFuture<java.lang.Void> onDisconnect() {
return this.disconnect.copy();
}
@Override @Override
public CompletableFuture<java.lang.Void> shutdown() { public CompletableFuture<java.lang.Void> shutdown() {
this.disconnect.complete(null); this.disconnect.complete(null);
return this.disconnect.copy(); return this.disconnect.thenRun(() -> {});
} }
@Override @Override
public RpcTwoPartyProtocol.VatId.Reader getPeerVatId() { public RpcTwoPartyProtocol.VatId.Reader getPeerVatId() {
return null; return null;
} }
@Override
public void close() {
}
} }
TestConnection connection; TestConnection connection;
@ -80,7 +79,7 @@ public class RpcStateTest {
final Queue<OutgoingRpcMessage> sent = new ArrayDeque<>(); final Queue<OutgoingRpcMessage> sent = new ArrayDeque<>();
@Before @Before
public void setUp() throws Exception { public void setUp() {
this.connection = new TestConnection(); this.connection = new TestConnection();
this.bootstrapInterface = new Capability.Client(Capability.newNullCap()); this.bootstrapInterface = new Capability.Client(Capability.newNullCap());
var bootstrapFactory = new BootstrapFactory<RpcTwoPartyProtocol.VatId.Reader>() { var bootstrapFactory = new BootstrapFactory<RpcTwoPartyProtocol.VatId.Reader>() {
@ -95,45 +94,50 @@ public class RpcStateTest {
} }
}; };
this.rpc = new RpcState<RpcTwoPartyProtocol.VatId.Reader>(bootstrapFactory, connection, connection.disconnect); this.rpc = new RpcState<>(bootstrapFactory, connection, connection.disconnect);
} }
@After @After
public void tearDown() throws Exception { public void tearDown() {
this.connection = null; this.connection = null;
this.rpc = null; this.rpc = null;
this.sent.clear(); this.sent.clear();
} }
/*
@Test @Test
public void handleUnimplemented() throws RpcException { public void handleUnimplemented() {
var msg = new TestMessage(); var msg = this.connection.newOutgoingMessage(0);
msg.builder.getRoot(RpcProtocol.Message.factory).initUnimplemented(); var root = msg.getBody().initAs(RpcProtocol.Message.factory).initUnimplemented();
this.connection.setNextIncomingMessage(msg); var resolve = root.initResolve();
RpcState.FromException(new Exception("foo"), resolve.initException());
msg.send();
Assert.assertFalse(sent.isEmpty());
} }
*/
@Test @Test
public void handleAbort() { public void handleAbort() {
var msg = new TestMessage(); var msg = this.connection.newOutgoingMessage(0);
var builder = msg.builder.getRoot(RpcProtocol.Message.factory); var builder = msg.getBody().initAs(RpcProtocol.Message.factory);
RpcState.FromException(RpcException.failed("Test abort"), builder.initAbort()); RpcState.FromException(RpcException.failed("Test abort"), builder.initAbort());
this.connection.setNextIncomingMessage(msg); msg.send();
//Assert.assertThrows(RpcException.class, () -> rpc.handleMessage(msg));
} }
@Test @Test
public void handleBootstrap() throws RpcException { public void handleBootstrap() {
var msg = new TestMessage(); var msg = this.connection.newOutgoingMessage(0);
var bootstrap = msg.builder.getRoot(RpcProtocol.Message.factory).initBootstrap(); var bootstrap = msg.getBody().initAs(RpcProtocol.Message.factory).initBootstrap();
bootstrap.setQuestionId(0); bootstrap.setQuestionId(0);
this.connection.setNextIncomingMessage(msg); msg.send();
Assert.assertFalse(sent.isEmpty()); Assert.assertEquals(2, sent.size());
var reply = sent.remove();
sent.remove(); // bootstrap
var reply = sent.remove(); // return
var rpcMsg = reply.getBody().getAs(RpcProtocol.Message.factory); var rpcMsg = reply.getBody().getAs(RpcProtocol.Message.factory);
Assert.assertEquals(rpcMsg.which(), RpcProtocol.Message.Which.RETURN); Assert.assertEquals(RpcProtocol.Message.Which.RETURN, rpcMsg.which());
var ret = rpcMsg.getReturn(); var ret = rpcMsg.getReturn();
Assert.assertEquals(ret.getAnswerId(), 0); Assert.assertEquals(ret.getAnswerId(), 0);
Assert.assertEquals(ret.which(), RpcProtocol.Return.Which.RESULTS); Assert.assertEquals(RpcProtocol.Return.Which.RESULTS, ret.which());
var results = ret.getResults(); var results = ret.getResults();
Assert.assertEquals(results.getCapTable().size(), 1); // got a capability! Assert.assertEquals(results.getCapTable().size(), 1); // got a capability!
Assert.assertTrue(results.hasContent()); Assert.assertTrue(results.hasContent());

View file

@ -25,13 +25,10 @@ import org.capnproto.rpctest.Test;
import org.junit.Assert; import org.junit.Assert;
import java.lang.ref.ReferenceQueue;
import java.lang.ref.WeakReference;
import java.util.ArrayDeque; import java.util.ArrayDeque;
import java.util.HashMap; import java.util.HashMap;
import java.util.Map; import java.util.Map;
import java.util.Queue; import java.util.Queue;
import java.util.concurrent.CancellationException;
import java.util.concurrent.CompletableFuture; import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionException; import java.util.concurrent.CompletionException;
import java.util.concurrent.ExecutionException; import java.util.concurrent.ExecutionException;
@ -156,11 +153,6 @@ public class RpcTest {
} }
} }
@Override
public CompletableFuture<java.lang.Void> onDisconnect() {
return null;
}
@Override @Override
public CompletableFuture<java.lang.Void> shutdown() { public CompletableFuture<java.lang.Void> shutdown() {
if (this.partner == null) { if (this.partner == null) {
@ -174,6 +166,10 @@ public class RpcTest {
public Test.TestSturdyRef.Reader getPeerVatId() { public Test.TestSturdyRef.Reader getPeerVatId() {
return this.peerId; return this.peerId;
} }
@Override
public void close() {
}
} }
final TestNetwork network; final TestNetwork network;
@ -430,6 +426,7 @@ public class RpcTest {
Assert.assertEquals(456, response.getI()); Assert.assertEquals(456, response.getI());
var dependentCall1 = promise.getC().getCallSequenceRequest().send(); var dependentCall1 = promise.getC().getCallSequenceRequest().send();
Assert.assertEquals(0, dependentCall0.join().getN()); Assert.assertEquals(0, dependentCall0.join().getN());
Assert.assertEquals(1, dependentCall1.join().getN()); Assert.assertEquals(1, dependentCall1.join().getN());

View file

@ -1,60 +1,15 @@
package org.capnproto; package org.capnproto;
/* import org.capnproto.rpctest.*;
import org.capnproto.demo.Demo;
import org.junit.After; import org.junit.After;
import org.junit.Assert; import org.junit.Assert;
import org.junit.Before; import org.junit.Before;
import org.junit.Test; import org.junit.function.ThrowingRunnable;
import java.io.IOException; import java.io.IOException;
import java.nio.channels.AsynchronousServerSocketChannel; import java.nio.channels.AsynchronousServerSocketChannel;
import java.nio.channels.AsynchronousSocketChannel; import java.nio.channels.AsynchronousSocketChannel;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutionException; import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeoutException;
class TestCap0Impl extends Demo.TestCap0.Server {
final Demo.TestCap1.Client testCap1a = new Demo.TestCap1.Client(new TestCap1Impl());
final Demo.TestCap1.Client testCap1b = new Demo.TestCap1.Client(new TestCap1Impl());
public CompletableFuture<java.lang.Void> testMethod0(CallContext<Demo.TestParams0.Reader, Demo.TestResults0.Builder> ctx) {
var params = ctx.getParams();
var results = ctx.getResults();
results.setResult0(params.getParam0());
ctx.releaseParams();
return CompletableFuture.completedFuture(null);
}
public CompletableFuture<java.lang.Void> testMethod1(CallContext<Demo.TestParams1.Reader, Demo.TestResults1.Builder> ctx) {
var params = ctx.getParams();
var results = ctx.getResults();
var res0 = results.getResult0();
res0.setAs(Demo.TestCap1.factory, testCap1a);
var res1 = results.getResult1();
res1.setAs(Demo.TestCap1.factory, testCap1b);
var res2 = results.getResult2();
res2.setAs(Demo.TestCap1.factory, testCap1b);
return CompletableFuture.completedFuture(null);
}
}
class TestCap1Impl extends Demo.TestCap1.Server {
}
class Tap implements org.capnproto.TwoPartyVatNetwork.MessageTap {
final RpcDumper dumper = new RpcDumper();
@Override
public void incoming(IncomingRpcMessage message, RpcTwoPartyProtocol.Side side) {
var text = this.dumper.dump(message.getBody().getAs(RpcProtocol.Message.factory), side);
if (text.length() > 0) {
System.out.println(text);
}
}
}
public class TwoPartyTest { public class TwoPartyTest {
@ -73,7 +28,8 @@ public class TwoPartyTest {
return thread; return thread;
} }
AsynchronousServerSocketChannel serverSocket; AsynchronousServerSocketChannel serverAcceptSocket;
AsynchronousSocketChannel serverSocket;
AsynchronousSocketChannel clientSocket; AsynchronousSocketChannel clientSocket;
TwoPartyClient client; TwoPartyClient client;
org.capnproto.TwoPartyVatNetwork serverNetwork; org.capnproto.TwoPartyVatNetwork serverNetwork;
@ -81,17 +37,17 @@ public class TwoPartyTest {
@Before @Before
public void setUp() throws Exception { public void setUp() throws Exception {
this.serverSocket = AsynchronousServerSocketChannel.open(); this.serverAcceptSocket = AsynchronousServerSocketChannel.open();
this.serverSocket.bind(null); this.serverAcceptSocket.bind(null);
this.clientSocket = AsynchronousSocketChannel.open(); this.clientSocket = AsynchronousSocketChannel.open();
this.clientSocket.connect(this.serverSocket.getLocalAddress()).get(); this.clientSocket.connect(this.serverAcceptSocket.getLocalAddress()).get();
this.client = new TwoPartyClient(clientSocket); this.client = new TwoPartyClient(clientSocket);
this.client.getNetwork().setTap(new Tap()); //this.client.getNetwork().setTap(new Tap());
var socket = serverSocket.accept().get(); this.serverSocket = serverAcceptSocket.accept().get();
this.serverNetwork = new org.capnproto.TwoPartyVatNetwork(socket, RpcTwoPartyProtocol.Side.SERVER); this.serverNetwork = new org.capnproto.TwoPartyVatNetwork(this.serverSocket, RpcTwoPartyProtocol.Side.SERVER);
this.serverNetwork.setTap(new Tap()); //this.serverNetwork.setTap(new Tap());
//this.serverNetwork.dumper.addSchema(Demo.TestCap1); //this.serverNetwork.dumper.addSchema(Demo.TestCap1);
this.serverThread = runServer(this.serverNetwork); this.serverThread = runServer(this.serverNetwork);
} }
@ -100,36 +56,128 @@ public class TwoPartyTest {
public void tearDown() throws Exception { public void tearDown() throws Exception {
this.clientSocket.close(); this.clientSocket.close();
this.serverSocket.close(); this.serverSocket.close();
this.serverAcceptSocket.close();
this.serverThread.join(); this.serverThread.join();
this.client = null; this.client = null;
} }
@Test @org.junit.Test
public void testNullCap() throws ExecutionException, InterruptedException { public void testNullCap() throws ExecutionException, InterruptedException {
var server = new RpcSystem<>(this.serverNetwork, new Capability.Client()); var server = new RpcSystem<>(this.serverNetwork, new Capability.Client());
var cap = this.client.bootstrap(); var cap = this.client.bootstrap();
var resolved = cap.whenResolved().toCompletableFuture(); var resolved = cap.whenResolved();
resolved.get(); resolved.get();
} }
@Test @org.junit.Test
public void testBasic() throws ExecutionException, InterruptedException, IOException { public void testBasic() throws InterruptedException, IOException {
var server = new RpcSystem<>(this.serverNetwork, new TestCap0Impl());
var demo = new Demo.TestCap0.Client(this.client.bootstrap()); var callCount = new Counter();
var request = demo.testMethod0Request(); var server = new RpcSystem<>(this.serverNetwork, new RpcTestUtil.TestInterfaceImpl(callCount));
var params = request.getParams();
params.setParam0(4321); var client = new Test.TestInterface.Client(this.client.bootstrap());
var response = request.send(); var request1 = client.fooRequest();
response.get(); request1.getParams().setI(123);
Assert.assertTrue(response.isDone()); request1.getParams().setJ(true);
var results = response.get();
Assert.assertEquals(params.getParam0(), results.getResult0()); var promise1 = request1.send();
var request2 = client.bazRequest();
RpcTestUtil.initTestMessage(request2.getParams().initS());
var promise2 = request2.send();
boolean barFailed = false;
var request3 = client.barRequest();
var promise3 = request3.send()
.thenAccept(results -> Assert.fail("Expected bar() to fail"))
.exceptionally(exc -> null);
var response1 = promise1.join();
Assert.assertEquals("foo", response1.getX().toString());
promise2.join();
promise3.join();
Assert.assertEquals(2, callCount.value());
this.clientSocket.shutdownOutput(); this.clientSocket.shutdownOutput();
serverThread.join(); serverThread.join();
} }
@Test @org.junit.Test
public void testDisconnect() throws IOException {
this.serverSocket.shutdownOutput();
this.serverNetwork.close();
this.serverNetwork.onDisconnect().join();
}
@org.junit.Test
public void testPipelining() throws IOException {
var callCount = new Counter();
var chainedCallCount = new Counter();
var server = new RpcSystem<>(this.serverNetwork, new RpcTestUtil.TestPipelineImpl(callCount));
var client = new Test.TestPipeline.Client(this.client.bootstrap());
{
var request = client.getCapRequest();
request.getParams().setN(234);
request.getParams().setInCap(new RpcTestUtil.TestInterfaceImpl(chainedCallCount));
var promise = request.send();
var pipelineRequest = promise.getOutBox().getCap().fooRequest();
pipelineRequest.getParams().setI(321);
var pipelinePromise = pipelineRequest.send();
var pipelineRequest2 = new Test.TestExtends.Client(promise.getOutBox().getCap()).graultRequest();
var pipelinePromise2 = pipelineRequest2.send();
promise = null;
//Assert.assertEquals(0, chainedCallCount.value());
var response = pipelinePromise.join();
Assert.assertEquals("bar", response.getX().toString());
var response2 = pipelinePromise2.join();
RpcTestUtil.checkTestMessage(response2);
Assert.assertEquals(1, chainedCallCount.value());
}
/*
// disconnect the server
//this.serverSocket.shutdownOutput();
this.serverNetwork.close();
this.serverNetwork.onDisconnect().join();
{
// Use the now-broken capability.
var request = client.getCapRequest();
request.getParams().setN(234);
request.getParams().setInCap(new RpcTestUtil.TestInterfaceImpl(chainedCallCount));
var promise = request.send();
var pipelineRequest = promise.getOutBox().getCap().fooRequest();
pipelineRequest.getParams().setI(321);
var pipelinePromise = pipelineRequest.send();
var pipelineRequest2 = new Test.TestExtends.Client(promise.getOutBox().getCap()).graultRequest();
var pipelinePromise2 = pipelineRequest2.send();
Assert.assertThrows(Exception.class, () -> pipelinePromise.join());
Assert.assertThrows(Exception.class, () -> pipelinePromise2.join());
Assert.assertEquals(3, callCount.value());
Assert.assertEquals(1, chainedCallCount.value());
}
*/
}
/*
@org.junit.Test
public void testBasicCleanup() throws ExecutionException, InterruptedException, TimeoutException { public void testBasicCleanup() throws ExecutionException, InterruptedException, TimeoutException {
var server = new RpcSystem<>(this.serverNetwork, new TestCap0Impl()); var server = new RpcSystem<>(this.serverNetwork, new TestCap0Impl());
var demo = new Demo.TestCap0.Client(this.client.bootstrap()); var demo = new Demo.TestCap0.Client(this.client.bootstrap());
@ -145,7 +193,7 @@ public class TwoPartyTest {
demo = null; demo = null;
} }
@Test @org.junit.Test
public void testShutdown() throws InterruptedException, IOException { public void testShutdown() throws InterruptedException, IOException {
var server = new RpcSystem<>(this.serverNetwork, new TestCap0Impl()); var server = new RpcSystem<>(this.serverNetwork, new TestCap0Impl());
var demo = new Demo.TestCap0.Client(this.client.bootstrap()); var demo = new Demo.TestCap0.Client(this.client.bootstrap());
@ -153,7 +201,7 @@ public class TwoPartyTest {
serverThread.join(); serverThread.join();
} }
@Test @org.junit.Test
public void testCallThrows() throws ExecutionException, InterruptedException { public void testCallThrows() throws ExecutionException, InterruptedException {
var impl = new Demo.TestCap0.Server() { var impl = new Demo.TestCap0.Server() {
public CompletableFuture<java.lang.Void> testMethod0(CallContext<Demo.TestParams0.Reader, Demo.TestResults0.Builder> ctx) { public CompletableFuture<java.lang.Void> testMethod0(CallContext<Demo.TestParams0.Reader, Demo.TestResults0.Builder> ctx) {
@ -185,7 +233,7 @@ public class TwoPartyTest {
} }
} }
@Test @org.junit.Test
public void testReturnCap() throws ExecutionException, InterruptedException { public void testReturnCap() throws ExecutionException, InterruptedException {
// send a capability back from the server to the client // send a capability back from the server to the client
var capServer = new TestCap0Impl(); var capServer = new TestCap0Impl();
@ -204,5 +252,5 @@ public class TwoPartyTest {
var cap2 = results.getResult2(); var cap2 = results.getResult2();
Assert.assertFalse(cap2.isNull()); Assert.assertFalse(cap2.isNull());
} }
*/
} }
*/

View file

@ -150,7 +150,7 @@ public final class Capability {
} }
private final class LocalClient implements ClientHook { private final class LocalClient implements ClientHook {
private final CompletableFuture<java.lang.Void> resolveTask; private CompletableFuture<java.lang.Void> resolveTask;
private ClientHook resolved; private ClientHook resolved;
private boolean blocked = false; private boolean blocked = false;
private final CapabilityServerSetBase capServerSet; private final CapabilityServerSetBase capServerSet;
@ -162,11 +162,16 @@ public final class Capability {
LocalClient(CapabilityServerSetBase capServerSet) { LocalClient(CapabilityServerSetBase capServerSet) {
Server.this.hook = this; Server.this.hook = this;
this.capServerSet = capServerSet; this.capServerSet = capServerSet;
startResolveTask();
}
var resolver = shortenPath(); private void startResolveTask() {
this.resolveTask = resolver != null var resolveTask = shortenPath();
? resolver.thenAccept(client -> this.resolved = client.getHook()) if (resolveTask != null) {
: null; this.resolveTask = resolveTask.thenAccept(cap -> {
this.resolved = cap.getHook();
});
}
} }
@Override @Override
@ -209,6 +214,7 @@ public final class Capability {
@Override @Override
public CompletableFuture<ClientHook> whenMoreResolved() { public CompletableFuture<ClientHook> whenMoreResolved() {
if (this.resolved != null) { if (this.resolved != null) {
System.out.println("Local client resolved! " + this.toString());
return CompletableFuture.completedFuture(this.resolved); return CompletableFuture.completedFuture(this.resolved);
} }
else if (this.resolveTask != null) { else if (this.resolveTask != null) {

View file

@ -5,12 +5,7 @@ public interface PipelineHook extends AutoCloseable {
ClientHook getPipelinedCap(PipelineOp[] ops); ClientHook getPipelinedCap(PipelineOp[] ops);
static PipelineHook newBrokenPipeline(Throwable exc) { static PipelineHook newBrokenPipeline(Throwable exc) {
return new PipelineHook() { return ops -> Capability.newBrokenCap(exc);
@Override
public ClientHook getPipelinedCap(PipelineOp[] ops) {
return Capability.newBrokenCap(exc);
}
};
} }
@Override @Override